# model.py

from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from collections import namedtuple

from module import *
from utils import *
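

# CycleGAN: two generators (A2B and B2A) and two discriminators, trained with
# adversarial losses plus an L1 cycle-consistency loss.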
class cyclegan(object):
    def __init__(self, sess, args):
        self.sess = sess
        self.batch_size = args.batch_size
        self.image_size = args.fine_size
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.dataset_dir = args.dataset_dir

        self.discriminator = discriminator
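        # Select the generator architecture (ResNet or U-Net) and the GAN
        # objective (least-squares or sigmoid cross-entropy).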
        if args.use_resnet:
            self.generator = generator_resnet
        else:
            self.generator = generator_unet
        if args.use_lsgan:
            self.criterionGAN = mae_criterion
        else:
            self.criterionGAN = sce_criterion

        OPTIONS = namedtuple('OPTIONS',
                             'batch_size image_size gf_dim df_dim output_c_dim is_training')
        self.options = OPTIONS._make((args.batch_size, args.fine_size,
                                      args.ngf, args.ndf, args.output_nc,
                                      args.phase == 'train'))

        self._build_model()
        self.saver = tf.train.Saver()
        self.pool = ImagePool(args.max_size)

    def _build_model(self):
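        # real_data packs an A image and a B image per example, concatenated
        # along the channel axis; slice them back apart below.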
        self.real_data = tf.placeholder(tf.float32,
                                        [None, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')
        self.real_A = self.real_data[:, :, :, :self.input_c_dim]
        self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
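
        # Translation cycles A -> B -> A and B -> A -> B; the second call to
        # each generator reuses its variables.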
        self.fake_B = self.generator(self.real_A, self.options, False, name="generatorA2B")
        self.fake_A_ = self.generator(self.fake_B, self.options, False, name="generatorB2A")
        self.fake_A = self.generator(self.real_B, self.options, True, name="generatorB2A")
        self.fake_B_ = self.generator(self.fake_A, self.options, True, name="generatorA2B")

        self.DB_fake = self.discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB")
        self.DA_fake = self.discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA")
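
        # Generator losses: an adversarial term (make the fakes score as real)
        # plus forward and backward cycle-consistency L1 terms weighted by L1_lambda.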
        self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
            + self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
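
        # Discriminator losses: real images should be classified as real and
        # pooled fake samples as fake; each loss averages the two terms.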
        self.fake_A_sample = tf.placeholder(tf.float32,
                                            [None, self.image_size, self.image_size,
                                             self.input_c_dim], name='fake_A_sample')
        self.fake_B_sample = tf.placeholder(tf.float32,
                                            [None, self.image_size, self.image_size,
                                             self.output_c_dim], name='fake_B_sample')
        self.DB_real = self.discriminator(self.real_B, self.options, reuse=True, name="discriminatorB")
        self.DA_real = self.discriminator(self.real_A, self.options, reuse=True, name="discriminatorA")
        self.DB_fake_sample = self.discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB")
        self.DA_fake_sample = self.discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA")

        self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
        self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
        self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2
        self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
        self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
        self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
        self.d_loss = self.da_loss + self.db_loss
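
        # Scalar summaries for TensorBoard.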
        self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
        self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])
        self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
        self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
        self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
        self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
        self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
        self.d_sum = tf.summary.merge(
            [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
             self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
             self.d_loss_sum]
        )
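
        # Test-time graph: placeholders for single-domain inputs, with the
        # trained generators reused for inference.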
        self.test_A = tf.placeholder(tf.float32,
                                     [None, self.image_size, self.image_size,
                                      self.input_c_dim], name='test_A')
        self.test_B = tf.placeholder(tf.float32,
                                     [None, self.image_size, self.image_size,
                                      self.output_c_dim], name='test_B')
        self.testB = self.generator(self.test_A, self.options, True, name="generatorA2B")
        self.testA = self.generator(self.test_B, self.options, True, name="generatorB2A")

        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        for var in t_vars:
            print(var.name)

    def train(self, args):
        """Train cyclegan"""
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        counter = 1
        start_time = time.time()

        if args.continue_train:
            if self.load(args.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        for epoch in range(args.epoch):
            dataA = glob('./datasets/{}/trainA/*.*'.format(self.dataset_dir))
            dataB = glob('./datasets/{}/trainB/*.*'.format(self.dataset_dir))
            np.random.shuffle(dataA)
            np.random.shuffle(dataB)
            batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
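            # Keep the initial learning rate for the first epoch_step epochs,
            # then decay it linearly to zero over the remaining epochs.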
            lr = (args.lr if epoch < args.epoch_step
                  else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step))

            for idx in range(0, batch_idxs):
                batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
                                       dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
                batch_images = [load_train_data(batch_file, args.load_size, args.fine_size)
                                for batch_file in batch_files]
                batch_images = np.array(batch_images).astype(np.float32)

                # Update G network and record fake outputs
                fake_A, fake_B, _, summary_str = self.sess.run(
                    [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
                    feed_dict={self.real_data: batch_images, self.lr: lr})
                self.writer.add_summary(summary_str, counter)
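
                # Mix the current fakes with a history of previously generated
                # images (ImagePool) before updating the discriminators.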
                [fake_A, fake_B] = self.pool([fake_A, fake_B])

                # Update D network
                _, summary_str = self.sess.run(
                    [self.d_optim, self.d_sum],
                    feed_dict={self.real_data: batch_images,
                               self.fake_A_sample: fake_A,
                               self.fake_B_sample: fake_B,
                               self.lr: lr})
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
                    epoch, idx, batch_idxs, time.time() - start_time))

                if np.mod(counter, args.print_freq) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)
                if np.mod(counter, args.save_freq) == 2:
                    self.save(args.checkpoint_dir, counter)

    def save(self, checkpoint_dir, step):
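        """Save a model checkpoint under checkpoint_dir/<dataset>_<image_size>."""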
        model_name = "cyclegan.model"
        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
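        """Restore the most recent checkpoint; return True on success."""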
        print(" [*] Reading checkpoint...")
        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def sample_model(self, sample_dir, epoch, idx):
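        """Translate a random test batch and save sample images to sample_dir."""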
        dataA = glob('./datasets/{}/testA/*.*'.format(self.dataset_dir))
        dataB = glob('./datasets/{}/testB/*.*'.format(self.dataset_dir))
        np.random.shuffle(dataA)
        np.random.shuffle(dataB)
        batch_files = list(zip(dataA[:self.batch_size], dataB[:self.batch_size]))
        sample_images = [load_train_data(batch_file, is_testing=True) for batch_file in batch_files]
        sample_images = np.array(sample_images).astype(np.float32)

        fake_A, fake_B = self.sess.run(
            [self.fake_A, self.fake_B],
            feed_dict={self.real_data: sample_images}
        )
        save_images(fake_A, [self.batch_size, 1],
                    './{}/A_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
        save_images(fake_B, [self.batch_size, 1],
                    './{}/B_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))

    def test(self, args):
        """Test cyclegan"""
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        if args.which_direction == 'AtoB':
            sample_files = glob('./datasets/{}/testA/*.*'.format(self.dataset_dir))
        elif args.which_direction == 'BtoA':
            sample_files = glob('./datasets/{}/testB/*.*'.format(self.dataset_dir))
        else:
            raise Exception('--which_direction must be AtoB or BtoA')

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Write an HTML page for visual comparison of inputs and outputs
        index_path = os.path.join(args.test_dir, '{0}_index.html'.format(args.which_direction))
        index = open(index_path, "w")
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>")

        out_var, in_var = (self.testB, self.test_A) if args.which_direction == 'AtoB' else (
            self.testA, self.test_B)
        for sample_file in sample_files:
            print('Processing image: ' + sample_file)
            sample_image = [load_test_data(sample_file, args.fine_size)]
            sample_image = np.array(sample_image).astype(np.float32)
            image_path = os.path.join(args.test_dir,
                                      '{0}_{1}'.format(args.which_direction, os.path.basename(sample_file)))
            fake_img = self.sess.run(out_var, feed_dict={in_var: sample_image})
            save_images(fake_img, [1, 1], image_path)
            index.write("<tr>")
            index.write("<td>%s</td>" % os.path.basename(image_path))
            index.write("<td><img src='%s'></td>" % (sample_file if os.path.isabs(sample_file) else (
                '..' + os.path.sep + sample_file)))
            index.write("<td><img src='%s'></td>" % (image_path if os.path.isabs(image_path) else (
                '..' + os.path.sep + image_path)))
            index.write("</tr>")
        index.write("</table></body></html>")
        index.close()