model.py

import os
import random

import numpy as np
import tensorflow as tf
from tqdm import trange, tqdm
# Note: scipy.misc.imsave requires SciPy < 1.2 with Pillow installed; it was
# removed from later SciPy releases (imageio.imwrite is the usual substitute).
from scipy.misc import imsave

from generator import Generator
from encoder import Encoder
from discriminator import Discriminator
from utils import logger
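

# BicycleGAN (Zhu et al., "Toward Multimodal Image-to-Image Translation",
# NIPS 2017) trains one generator G with two complementary cycles that share
# a discriminator D and an encoder E:
#   cVAE-GAN:  B -> E -> z -> G(A, z) -> B'  (reconstruct B through the latent)
#   cLR-GAN:   z -> G(A, z) -> E -> z'       (recover the latent from the output)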
class BicycleGAN(object):
    def __init__(self, args):
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._latent_dim = args.latent_dim
        self._coeff_gan = args.coeff_gan
        self._coeff_vae = args.coeff_vae
        self._coeff_reconstruct = args.coeff_reconstruct
        self._coeff_latent = args.coeff_latent
        self._coeff_kl = args.coeff_kl
        self._norm = 'instance' if args.instance_normalization else 'batch'
        self._use_resnet = args.use_resnet

        # Images are resized slightly larger than the target size so that a
        # random crop can be taken during augmentation.
        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.train.get_or_create_global_step(graph=None)

        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')

        # Data augmentation: resize up, random-crop back down, random flip.
        # A single op-level seed is shared by the crop/flip ops so that the
        # paired images A and B receive identical augmentations.
        seed = random.randint(0, 2**31 - 1)

        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
            return image
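
        # tf.cond switches augmentation on only when is_train is fed as True;
        # at evaluation time the placeholder images pass through unchanged.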
        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)

        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm=self._norm, image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm=self._norm, activation='leaky',
                          image_size=self._image_size)

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm=self._norm, activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim,
                    use_resnet=self._use_resnet)

        # conditional VAE-GAN: B -> z -> B'
        z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # conditional Latent Regressor-GAN: z -> B' -> z'
        image_ab = self.image_ab = G(image_a, z)
        z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)

        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)

        # Least-squares GAN losses with a smoothed real label (0.9 instead of 1).
        loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                        tf.reduce_mean(tf.square(D_fake_encoded)))
        loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                    tf.reduce_mean(tf.square(D_fake)))
        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))
        # KL(q(z|B) || N(0, I)) for the encoder's Gaussian posterior, averaged
        # over the batch and latent dimensions.
        loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
                                        tf.exp(2 * z_encoded_log_sigma))

        loss = self._coeff_vae * loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
            self._coeff_gan * loss_gan - self._coeff_latent * loss_latent_cycle - \
            self._coeff_kl * loss_kl
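
        # A single scalar encodes the adversarial game: D descends `loss` while
        # G and E descend `-loss` (see the optimizers below). The reconstruction,
        # latent-cycle, and KL terms enter with negative sign, so ascending the
        # scalar makes G and E minimize them while fooling D.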
        # Optimizer
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(loss, var_list=D.var_list, global_step=self.global_step)
            self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(-loss, var_list=G.var_list)
            self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(-loss, var_list=E.var_list)
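
        # Note: only optimizer_D advances global_step, so one global step
        # corresponds to one full D/G/E update; running under update_ops keeps
        # batch-norm moving statistics current.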

        # Summaries
        self.loss_vae_gan = loss_vae_gan
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_kl = loss_kl
        self.loss = loss

        tf.summary.scalar('loss/vae_gan', loss_vae_gan)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/kl', loss_kl)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()

    def train(self, sess, summary_writer, data_A, data_B):
        logger.info('Start training.')
        logger.info('  {} images from A'.format(len(data_A)))
        logger.info('  {} images from B'.format(len(data_B)))
        assert len(data_A) == len(data_B), \
            'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))

        data_size = len(data_A)
        num_batch = data_size // self._batch_size
        # One optimization step consumes one batch, so an epoch is num_batch steps.
        epoch_length = num_batch

        # Keep the learning rate constant for num_initial_iter epochs, then
        # decay it linearly over num_decay_iter epochs.
        num_initial_iter = 8
        num_decay_iter = 2
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter

        if not os.path.isdir('results'):
            os.makedirs('results')  # intermediate samples are written here

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)

        for step in t:
            # TODO: resume training with global_step
            epoch = step // epoch_length
            iter = step % epoch_length

            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            # Reshuffle the paired data at the start of every epoch.
            if iter == 0:
                data = list(zip(data_A, data_B))
                random.shuffle(data)
                data_A, data_B = zip(*data)

            image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
            image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
            sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
  141. sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
  142. fetches = [self.loss, self.optimizer_D,
  143. self.optimizer_G, self.optimizer_E]
  144. if step % self._log_step == 0:
  145. fetches += [self.summary_op]
  146. fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
  147. self.image_b: image_b,
  148. self.is_train: True,
  149. self.lr: lr,
  150. self.z: sample_z})
  151. if step % self._log_step == 0:
  152. z = np.random.normal(size=(1, self._latent_dim))
  153. image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
  154. self.z: z,
  155. self.is_train: False})
  156. imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
  157. summary_writer.add_summary(fetched[-1], step)
  158. summary_writer.flush()
  159. t.set_description('Loss({:.3f})'.format(fetched[0]))

    def test(self, sess, data_A, data_B, base_dir):
        # Assumes the model was built with batch_size == 1, since single
        # images are fed into the fixed-size placeholders.
        step = 0
        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
            step += 1

            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)

            # Build two 5x5 grids: the input pair followed by 23 outputs, once
            # with random z draws and once with z swept along its first axis.
            images_random = [image_a, image_b]
            images_linear = [image_a, image_b]

            for i in range(23):
                z = np.random.normal(size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_random.append(image_ab)

                z = np.zeros((1, self._latent_dim))
                z[0][0] = (i / 23.0 - 0.5) * 2.0  # sweep the first latent from -1 to ~0.91
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_linear.append(image_ab)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_random[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_linear[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
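

# A minimal driver sketch (not part of the original file): the names `args`
# and `load_images`, and the coefficient values, are hypothetical stand-ins
# for whatever the surrounding project provides.
#
#     args = argparse.Namespace(
#         log_step=100, batch_size=1, image_size=256, latent_dim=8,
#         coeff_gan=1.0, coeff_vae=1.0, coeff_reconstruct=10.0,
#         coeff_latent=0.5, coeff_kl=0.01,
#         instance_normalization=False, use_resnet=True)
#     model = BicycleGAN(args)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         writer = tf.summary.FileWriter('logs', sess.graph)
#         model.train(sess, writer, load_images('datasets/A'), load_images('datasets/B'))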