model.py

import os
import random

from tqdm import trange, tqdm
from scipy.misc import imsave
import tensorflow as tf
import numpy as np

from generator import Generator
from encoder import Encoder
from discriminator import Discriminator
from utils import logger


class BicycleGAN(object):
    def __init__(self, args):
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._latent_dim = args.latent_dim
        self._coeff_reconstruct = args.coeff_reconstruct
        self._coeff_latent = args.coeff_latent
        self._coeff_kl = args.coeff_kl

        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)

        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')
        # Data augmentation: upscale, random-crop back to the target size, and
        # random horizontal flip. The same seed is shared by the crop/flip ops
        # for both inputs so the paired A/B images receive matching transforms.
        seed = random.randint(0, 2**31 - 1)

        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
            return image

        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)

        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm='batch', image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm='batch', activation='leaky',
                          image_size=self._image_size)

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm='batch', activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim)

        # conditional VAE-GAN: B -> z -> B'
        z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # conditional Latent Regressor-GAN: z -> B' -> z'
        image_ab = self.image_ab = G(image_a, z)
        z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)

        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)

        # Losses (least-squares GAN terms with the real target smoothed to 0.9)
        loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                        tf.reduce_mean(tf.square(D_fake_encoded)))

        loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))

        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                    tf.reduce_mean(tf.square(D_fake)))

        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))

        # Scalar KL term, averaged over both the batch and latent dimensions
        loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
                                        tf.exp(2 * z_encoded_log_sigma))

        loss = loss_vae_gan + self._coeff_reconstruct * loss_image_cycle + \
            loss_gan + self._coeff_latent * loss_latent_cycle + \
            self._coeff_kl * loss_kl
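
        # Reference note: with a diagonal Gaussian encoder q(z|B) = N(mu, sigma^2)
        # and a unit Gaussian prior, the KL term above is the standard closed form
        #     KL(q || N(0, I)) = -0.5 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2),
        # where z_encoded_log_sigma holds log sigma, so 2 * log_sigma = log sigma^2
        # and exp(2 * log_sigma) = sigma^2; here the sum over dimensions is
        # replaced by a mean, which only rescales the term by 1/latent_dim.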

        # Optimizers: D descends the combined objective while G and E ascend it
        # (they minimize -loss), forming the adversarial min-max game.
        self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss, var_list=D.var_list, global_step=self.global_step)
        self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(-loss, var_list=G.var_list)
        self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(-loss, var_list=E.var_list)

        # Summaries
        self.loss_vae_gan = loss_vae_gan
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_kl = loss_kl
        self.loss = loss

        tf.summary.scalar('loss/vae_gan', loss_vae_gan)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/kl', loss_kl)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()
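
        # Attributes consumed by train()/test(): the feed placeholders
        # (image_a, image_b, z, is_train, lr), the three per-network optimizers,
        # self.global_step, self.loss for progress display, self.summary_op for
        # TensorBoard, and self.image_ab for sampling translated images.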

    def train(self, sess, summary_writer, data_A, data_B):
        logger.info('Start training.')
        logger.info(' {} images from A'.format(len(data_A)))
        logger.info(' {} images from B'.format(len(data_B)))

        assert len(data_A) == len(data_B), \
            'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))
        data_size = len(data_A)
        num_batch = data_size // self._batch_size
        epoch_length = num_batch * self._batch_size

        num_initial_iter = 8
        num_decay_iter = 2
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)

        for step in t:
            # TODO: resume training with global_step
            epoch = step // epoch_length
            iter = step % epoch_length

            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            # Reshuffle the paired data at the start of each epoch
            if iter == 0:
                data = list(zip(data_A, data_B))
                random.shuffle(data)
                data_A, data_B = zip(*data)

            # NOTE: the per-step slicing below assumes batch_size == 1
            image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
            image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
            # sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
            sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))

            fetches = [self.loss,
                       self.optimizer_D,
                       self.optimizer_G, self.optimizer_E]
            if step % self._log_step == 0:
                fetches += [self.summary_op]

            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                   self.image_b: image_b,
                                                   self.is_train: True,
                                                   self.lr: lr,
                                                   self.z: sample_z})

            if step % self._log_step == 0:
                # Dump one translated sample and write the summaries
                z = np.random.normal(size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
                summary_writer.add_summary(fetched[-1], step)
                summary_writer.flush()
                t.set_description('Loss({:.3f})'.format(fetched[0]))

    def test(self, sess, data_A, data_B, base_dir):
        step = 0
        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
            step += 1

            # Each output grid is 5x5 tiles: the input A, the ground-truth B,
            # and 23 translated samples.
            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)
            images_random = []
            images_random.append(image_a)
            images_random.append(image_b)
            images_linear = []
            images_linear.append(image_a)
            images_linear.append(image_b)

            for i in range(23):
                # Random z draws for the 'random' grid
                z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_random.append(image_ab)

                # Sweep the first latent dimension from -1 toward 1 for the 'linear' grid
                z = np.zeros((1, self._latent_dim))
                z[0][0] = (i / 23.0 - 0.5) * 2.0
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: False})
                images_linear.append(image_ab)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_random[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images_linear[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
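

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal driver showing how this class is wired together: build the args the
# constructor expects, create a session, and call train(). The hyperparameter
# defaults, the random toy data, and the 'checkpoints' log directory below are
# assumptions for illustration; real runs would load paired A/B image lists and
# use the project's own argument parser instead.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--log_step', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)    # code above assumes 1
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--latent_dim', type=int, default=8)    # assumed default
    parser.add_argument('--coeff_reconstruct', type=float, default=10.0)  # assumed
    parser.add_argument('--coeff_latent', type=float, default=0.5)        # assumed
    parser.add_argument('--coeff_kl', type=float, default=0.01)           # assumed
    args = parser.parse_args()

    # Toy stand-in data: two small lists of paired HWC float images
    data_A = [np.random.rand(args.image_size, args.image_size, 3).astype(np.float32)
              for _ in range(4)]
    data_B = [np.random.rand(args.image_size, args.image_size, 3).astype(np.float32)
              for _ in range(4)]

    if not os.path.exists('results'):
        os.makedirs('results')  # train() writes sample images here

    model = BicycleGAN(args)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter('checkpoints', graph=sess.graph)
        model.train(sess, summary_writer, data_A, data_B)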