model.py

import os
import random

from tqdm import trange, tqdm
from scipy.misc import imsave  # removed from newer SciPy releases; needs an older SciPy (or imageio)
import tensorflow as tf
import numpy as np

from generator import Generator
from encoder import Encoder
from discriminator import Discriminator
from utils import logger


class BicycleGAN(object):
    """BicycleGAN (Zhu et al., 2017) for multimodal image-to-image translation,
    combining a conditional VAE-GAN cycle (B -> z -> B') with a conditional
    latent-regressor cycle (z -> B' -> z')."""

    def __init__(self, args):
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._latent_dim = args.latent_dim
        self._coeff_reconstruct = args.coeff_reconstruct
        self._coeff_latent = args.coeff_latent
        self._coeff_kl = args.coeff_kl
        self._norm = 'instance' if args.instance_normalization else 'batch'
        self._use_resnet = args.use_resnet

        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.train.get_or_create_global_step(graph=None)

        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')
        # Data augmentation: upscale, random-crop back to the training size and
        # random horizontal flip; applied only when is_train is True.
        seed = random.randint(0, 2**31 - 1)

        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed=seed), image)
            return image

        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)
        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm=self._norm, image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm=self._norm, activation='leaky',
                          image_size=self._image_size)

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm=self._norm, activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim,
                    use_resnet=self._use_resnet)

        # conditional VAE-GAN: B -> z -> B'
        z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # conditional Latent Regressor-GAN: z -> B' -> z'
        image_ab = self.image_ab = G(image_a, z)
        z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)

        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)
        loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                        tf.reduce_mean(tf.square(D_fake_encoded)))

        loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))

        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                    tf.reduce_mean(tf.square(D_fake)))

        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))

        loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
                                        tf.exp(2 * z_encoded_log_sigma))

        loss = loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
               loss_gan - self._coeff_latent * loss_latent_cycle - \
               self._coeff_kl * loss_kl
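
        # Note: loss_kl above is the closed-form KL divergence between the
        # encoder's Gaussian q(z|B) = N(mu, sigma^2) and the unit prior N(0, I),
        #   KL = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2),
        # with log(sigma^2) = 2 * z_encoded_log_sigma; reduce_mean averages it
        # over the batch and latent dimensions rather than summing.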

        # Optimizer
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(loss, var_list=D.var_list, global_step=self.global_step)
            self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(-loss, var_list=G.var_list)
            self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(-loss, var_list=E.var_list)
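
        # Note: a single scalar `loss` drives all three updates.  D minimizes it
        # directly (pulling D_real toward the 0.9 soft label and the fake outputs
        # toward 0), while G and E minimize `-loss`.  That is why the image-cycle,
        # latent-cycle and KL terms enter `loss` with negative coefficients:
        # maximizing `loss` makes G and E shrink exactly those terms.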

        # Summaries
        self.loss_vae_gan = loss_vae_gan
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_kl = loss_kl
        self.loss = loss

        tf.summary.scalar('loss/vae_gan', loss_vae_gan)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/kl', loss_kl)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()

    def train(self, sess, summary_writer, data_A, data_B):
        logger.info('Start training.')
        logger.info(' {} images from A'.format(len(data_A)))
        logger.info(' {} images from B'.format(len(data_B)))
        assert len(data_A) == len(data_B), \
            'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))
        data_size = len(data_A)
        num_batch = data_size // self._batch_size
        # One optimization step consumes one batch, so an epoch is num_batch steps.
        epoch_length = num_batch

        num_initial_iter = 8
        num_decay_iter = 2
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter  # per-epoch decrement once decay starts

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)

        for step in t:
            # TODO: resume training with global_step
            epoch = step // epoch_length
            iter = step % epoch_length

            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            # Re-shuffle the paired data at the start of every epoch.
            if iter == 0:
                data = list(zip(data_A, data_B))  # materialize: zip() is lazy in Python 3
                random.shuffle(data)
                data_A, data_B = zip(*data)
            image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
            image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
            sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))

            fetches = [self.loss, self.optimizer_D,
                       self.optimizer_G, self.optimizer_E]
            if step % self._log_step == 0:
                fetches += [self.summary_op]

            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                   self.image_b: image_b,
                                                   self.is_train: True,
                                                   self.lr: lr,
                                                   self.z: sample_z})

            if step % self._log_step == 0:
                # Sample a fresh z for the whole batch (the placeholder has a fixed
                # batch dimension) and save the first generated image.
                z = np.random.normal(size=(self._batch_size, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                imsave('results/r_{}.jpg'.format(step), image_ab[0])

                summary_writer.add_summary(fetched[-1], step)
                summary_writer.flush()
                t.set_description('Loss({:.3f})'.format(fetched[0]))

    def test(self, sess, data_A, data_B, base_dir):
        # Note: the placeholders were built with a fixed batch dimension, so this
        # evaluation loop assumes the model was constructed with batch_size == 1.
        step = 0
        for (dataA, dataB) in tqdm(zip(data_A, data_B), total=len(data_A)):
            step += 1
            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)

            images_random = []
            images_random.append(image_a)
            images_random.append(image_b)
            images_linear = []
            images_linear.append(image_a)
            images_linear.append(image_b)
            for i in range(23):
                # Randomly sampled latent codes.
                z = np.random.normal(size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_random.append(image_ab)

                # Linear sweep of the first latent dimension from -1 toward 1.
                z = np.zeros((1, self._latent_dim))
                z[0][0] = (i / 23.0 - 0.5) * 2.0
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_linear.append(image_ab)
            # Tile the 25 images (input, target and 23 samples) into a 5x5 grid.
            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_random[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_linear[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
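

# ---------------------------------------------------------------------------
# Usage sketch: one plausible way to build the graph and kick off training.
# The argparse fields mirror exactly what __init__ reads above; the default
# values and the `load_image_pairs` helper are illustrative assumptions, not
# something this file defines.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--log_step', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--latent_dim', type=int, default=8)
    parser.add_argument('--coeff_reconstruct', type=float, default=10.0)
    parser.add_argument('--coeff_latent', type=float, default=0.5)
    parser.add_argument('--coeff_kl', type=float, default=0.01)
    parser.add_argument('--instance_normalization', action='store_true')
    parser.add_argument('--use_resnet', action='store_true')
    args = parser.parse_args()

    model = BicycleGAN(args)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter('logs', graph=sess.graph)

        # `load_image_pairs` stands in for whatever data pipeline supplies two
        # aligned lists of HxWx3 arrays; it is not defined in this repository.
        # data_A, data_B = load_image_pairs('datasets/edges2shoes')
        # model.train(sess, summary_writer, data_A, data_B)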