model.py

import os
import random

from tqdm import trange, tqdm
# Note: scipy.misc.imsave was removed in SciPy >= 1.2; on newer SciPy,
# imageio.imwrite is the usual drop-in replacement.
from scipy.misc import imsave
import tensorflow as tf
import numpy as np

from generator import Generator
from encoder import Encoder
from discriminator import Discriminator
from discriminator_z import DiscriminatorZ
from utils import logger
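

# BicycleGAN: multimodal image-to-image translation. G maps (image_a, z) to an
# output image, E encodes target images into latent codes, D discriminates real
# vs. generated images, and Dz discriminates encoded codes from the prior on z.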
class BicycleGAN(object):
    def __init__(self, args):
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._latent_dim = args.latent_dim
        self._lambda1 = args.lambda1
        self._lambda2 = args.lambda2
        self._gamma = args.gamma
        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)

        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')

        # Data augmentation: resize up, then random-crop back to the target size and
        # random-flip, applied only when is_train is True. A single seed is shared so
        # that image_a and image_b are intended to receive matching crops/flips.
        seed = random.randint(0, 2**31 - 1)

        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed=seed), image)
            return image

        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)

        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm='batch', image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm='batch', activation='leaky',
                          image_size=self._image_size)
        Dz = DiscriminatorZ('Dz', is_train=self.is_train,
                            norm='batch', activation='relu')

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm='batch', activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim)

        # Generate images (a->b) from a random latent code, then reconstruct the code
        image_ab = self.image_ab = G(image_a, z)
        z_reconstruct = E(image_ab)

        # Encode the real target image (E(B) -> z) and translate with that code
        z_encoded = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)
        Dz_real = Dz(z)
        Dz_fake = Dz(z_encoded)
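
        # Losses: LSGAN-style squared losses (real target 0.9 for one-sided label
        # smoothing, fake target 0), plus L1 image and latent reconstruction terms.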
        loss_image_reconstruct = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))

        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                    tf.reduce_mean(tf.square(D_fake))) * 0.5

        loss_image_cycle = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                            tf.reduce_mean(tf.square(D_fake_encoded))) * 0.5

        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_reconstruct))

        loss_Dz = (tf.reduce_mean(tf.squared_difference(Dz_real, 0.9)) +
                   tf.reduce_mean(tf.square(Dz_fake))) * 0.5

        loss = self._gamma * loss_Dz \
            + loss_image_cycle - self._lambda1 * loss_image_reconstruct \
            + loss_gan - self._lambda2 * loss_latent_cycle
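
        # `loss` is a single minimax objective: D and Dz minimize it directly, while
        # G and E maximize it by minimizing -loss, so the reconstruction terms enter
        # with a negative sign here and a positive sign in the G/E updates.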
        # Optimizer
        self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss, var_list=D.var_list, global_step=self.global_step)
        self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(-loss, var_list=G.var_list)
        self.optimizer_Dz = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss, var_list=Dz.var_list)
        self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(-loss, var_list=E.var_list)

        # Summaries
        self.loss_image_reconstruct = loss_image_reconstruct
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_Dz = loss_Dz
        self.loss = loss

        tf.summary.scalar('loss/image_reconstruct', loss_image_reconstruct)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/Dz', loss_Dz)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/Dz_real', tf.reduce_mean(Dz_real))
        tf.summary.scalar('model/Dz_fake', tf.reduce_mean(Dz_fake))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()
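
    # Note: every training step runs all four optimizers (D, Dz, G, E) on the same
    # minibatch. The sample-saving path below appears to assume batch_size == 1,
    # since z is fed with a single row and the output is squeezed along axis 0.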
    def train(self, sess, summary_writer, data_A, data_B):
        logger.info('Start training.')
        logger.info(' {} images from A'.format(len(data_A)))
        logger.info(' {} images from B'.format(len(data_B)))

        data_size = min(len(data_A), len(data_B))
        num_batch = data_size // self._batch_size
        epoch_length = num_batch * self._batch_size

        num_initial_iter = 8
        num_decay_iter = 2
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)
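
        # Learning-rate schedule: constant at lr_initial for the first
        # num_initial_iter epochs, then decayed linearly toward zero over the
        # remaining num_decay_iter epochs.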
        for step in t:
            # TODO: resume training with global_step
            epoch = step // epoch_length
            iter = step % epoch_length

            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            if iter == 0:
                # Reshuffle the paired data at the start of every epoch.
                # zip() returns an iterator in Python 3, so materialize it first.
                data = list(zip(data_A, data_B))
                random.shuffle(data)
                data_A, data_B = zip(*data)

            image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
            image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
            sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))

            fetches = [self.loss,
                       self.optimizer_D, self.optimizer_Dz,
                       self.optimizer_G, self.optimizer_E]
            if step % self._log_step == 0:
                fetches += [self.summary_op]

            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                   self.image_b: image_b,
                                                   self.is_train: True,
                                                   self.lr: lr,
                                                   self.z: sample_z})

            # Save a sample translation for the current batch.
            z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
            image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                          self.image_b: image_b,
                                                          self.lr: lr,
                                                          self.z: z,
                                                          self.is_train: True})
            imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))

            if step % self._log_step == 0:
                summary_writer.add_summary(fetched[-1], step)
                summary_writer.flush()
                t.set_description('Loss({:.3f})'.format(fetched[0]))
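
    # test() writes two kinds of 5x5 sample grids per input pair: one generated
    # with random latent codes ('random_*.jpg') and one sweeping the first latent
    # dimension linearly ('linear_*.jpg').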
    def test(self, sess, data_A, data_B, base_dir):
        step = 0
        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
            step += 1
            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)

            # Input pair followed by 23 samples drawn with random latent codes.
            images = []
            images.append(image_a)
            images.append(image_b)
            for i in range(23):
                z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: True})
                images.append(image_ab)

            # Tile the 25 images into a 5x5 grid and save it.
            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)
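
        # Second pass: instead of random codes, sweep the first latent dimension
        # from -1 toward +1 while holding the other dimensions at zero.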
        step = 0
        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
            step += 1
            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)

            images = []
            images.append(image_a)
            images.append(image_b)
            for i in range(23):
                z = np.zeros((1, self._latent_dim))
                z[0][0] = (i / 23.0 - 0.5) * 2.0
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                              self.z: z,
                                                              self.is_train: True})
                images.append(image_ab)

            image_rows = []
            for i in range(5):
                image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
            images = np.concatenate(image_rows, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)