model.py

import os
import random

from tqdm import trange, tqdm
#from scipy.misc import imsave
from skimage.io import imsave
import tensorflow as tf
import numpy as np

from generator import Generator
from encoder import Encoder
from discriminator import Discriminator
from utils import logger


class BicycleGAN(object):
    def __init__(self, args):
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._latent_dim = args.latent_dim
        self._coeff_gan = args.coeff_gan
        self._coeff_vae = args.coeff_vae
        self._coeff_reconstruct = args.coeff_reconstruct
        self._coeff_latent = args.coeff_latent
        self._coeff_kl = args.coeff_kl
        self._norm = 'instance' if args.instance_normalization else 'batch'
        self._use_resnet = args.use_resnet

        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.train.get_or_create_global_step(graph=None)

        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')
        # Data augmentation: resize up, random-crop back to the original size,
        # and random horizontal flip. A shared seed keeps the A/B pair aligned.
        seed = random.randint(0, 2**31 - 1)

        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed=seed), image)
            return image

        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)
        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm=self._norm, image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm=self._norm, activation='leaky',
                          image_size=self._image_size)

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm=self._norm, activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim,
                    use_resnet=self._use_resnet)

        # conditional VAE-GAN: B -> z -> B'
        z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # conditional Latent Regressor-GAN: z -> B' -> z'
        image_ab = self.image_ab = G(image_a, z)
        z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)

        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)

        # Least-squares GAN losses (real label smoothed to 0.9)
        loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                        tf.reduce_mean(tf.square(D_fake_encoded)))
        loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                    tf.reduce_mean(tf.square(D_fake)))
        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))
        loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
                                        tf.exp(2 * z_encoded_log_sigma))

        # Single min-max objective: D minimizes it while G and E minimize its negative
        # (see the optimizers below), so the reconstruction, latent-cycle, and KL terms
        # enter with negative signs here.
        loss = self._coeff_vae * loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
               self._coeff_gan * loss_gan - self._coeff_latent * loss_latent_cycle - \
               self._coeff_kl * loss_kl
        # Optimizers: D minimizes the objective; G and E minimize its negative
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(loss, var_list=D.var_list, global_step=self.global_step)
            self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(-loss, var_list=G.var_list)
            self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                .minimize(-loss, var_list=E.var_list)
        # Summaries
        self.loss_vae_gan = loss_vae_gan
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_kl = loss_kl
        self.loss = loss

        tf.summary.scalar('loss/vae_gan', loss_vae_gan)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/kl', loss_kl)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()
    def train(self, sess, summary_writer, data_A, data_B):
        logger.info('Start training.')
        logger.info(' {} images from A'.format(len(data_A)))
        logger.info(' {} images from B'.format(len(data_B)))
        assert len(data_A) == len(data_B), \
            'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))
        data_size = len(data_A)
        num_batch = data_size // self._batch_size
        epoch_length = num_batch * self._batch_size

        num_initial_iter = 8
        num_decay_iter = 2
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)

        for step in t:
            #TODO: resume training with global_step
            epoch = step // epoch_length
            iter = step % epoch_length

            # Linearly decay the learning rate after the initial epochs
            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            # Reshuffle the paired data at the start of every epoch
            if iter == 0:
                data = list(zip(data_A, data_B))
                random.shuffle(data)
                data_A, data_B = zip(*data)

            image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
            image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
            sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))

            fetches = [self.loss, self.optimizer_D,
                       self.optimizer_G, self.optimizer_E]
            if step % self._log_step == 0:
                fetches += [self.summary_op]

            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                   self.image_b: image_b,
                                                   self.is_train: True,
                                                   self.lr: lr,
                                                   self.z: sample_z})

            if step % self._log_step == 0:
                # Save a sample translation and write summaries
                z = np.random.normal(size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))

                summary_writer.add_summary(fetched[-1], step)
                summary_writer.flush()
                t.set_description('Loss({:.3f})'.format(fetched[0]))
    def test(self, sess, data_A, data_B, base_dir):
        step = 0
        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
            step += 1

            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)

            images_random = []
            images_random.append(image_a)
            images_random.append(image_b)
            images_linear = []
            images_linear.append(image_a)
            images_linear.append(image_b)

            for i in range(23):
                # Sample a random latent code
                z = np.random.normal(size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_random.append(image_ab)

                # Sweep the first latent dimension from -1 toward 1
                z = np.zeros((1, self._latent_dim))
                z[0][0] = (i / 23.0 - 0.5) * 2.0
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_linear.append(image_ab)

            # Tile the 25 images (input, target, 23 samples) into a 5x5 grid
            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_random[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_linear[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
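

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it shows how the
# class above can be constructed and trained end to end. The hyperparameter
# values and the dummy random data below are illustrative assumptions; only
# the BicycleGAN constructor and the train() signature come from this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    # Assumed hyperparameters; the repo's own argument parser may use others.
    args = Namespace(log_step=100, batch_size=1, image_size=256, latent_dim=8,
                     coeff_gan=1.0, coeff_vae=1.0, coeff_reconstruct=10.0,
                     coeff_latent=0.5, coeff_kl=0.01,
                     instance_normalization=False, use_resnet=True)

    # Stand-in data: two random images per domain in [-1, 1], HxWx3 float32.
    data_A = [np.random.uniform(-1, 1, (256, 256, 3)).astype(np.float32) for _ in range(2)]
    data_B = [np.random.uniform(-1, 1, (256, 256, 3)).astype(np.float32) for _ in range(2)]

    # train() writes sample images to results/, so make sure it exists.
    if not os.path.exists('results'):
        os.makedirs('results')

    model = BicycleGAN(args)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter('logs', sess.graph)
        model.train(sess, writer, data_A, data_B)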