|
@@ -0,0 +1,246 @@
|
|
|
+import os
|
|
|
+import random
|
|
|
+
|
|
|
+from tqdm import trange, tqdm
|
|
|
+from scipy.misc import imsave
|
|
|
+import tensorflow as tf
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+from generator import Generator
|
|
|
+from encoder import Encoder
|
|
|
+from discriminator import Discriminator
|
|
|
+from discriminator_z import DiscriminatorZ
|
|
|
+from utils import logger
|
|
|
+
|
|
|
+
|
|
|
class BicycleGAN(object):
    """BicycleGAN graph builder (TF1 static graph).

    Constructs the generator G, encoder E, image discriminator D and latent
    discriminator Dz, the combined min-max loss, one Adam optimizer per
    network, and all TensorBoard summaries.
    """

    def __init__(self, args):
        """Build the full training/inference graph.

        Args:
            args: parsed options providing log_step, batch_size, image_size,
                latent_dim, lambda1 (image-reconstruction weight),
                lambda2 (latent-cycle weight) and gamma (Dz weight).
        """
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._latent_dim = args.latent_dim
        self._lambda1 = args.lambda1
        self._lambda2 = args.lambda2
        self._gamma = args.gamma

        # Images are upscaled by a margin before random-cropping back to
        # image_size (30 px margin for 256-sized images, 15 px otherwise).
        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        # Runtime switches fed per sess.run call.
        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)

        # Paired inputs: image_a (source domain), image_b (target domain),
        # and the latent code z.
        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')

        # Data augmentation
        # A single random seed is shared by both crop ops so that the A and B
        # images of a pair receive the same crop offset (keeps pairs aligned).
        seed = random.randint(0, 2**31 - 1)
        def augment_image(image):
            # Upscale, random-crop back to the target size, then flip each
            # image in the batch independently (same seed ties A/B flips too).
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
            return image

        # Augmentation is applied only when is_train is fed as True.
        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)

        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm='batch', image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm='batch', activation='leaky',
                          image_size=self._image_size)
        Dz = DiscriminatorZ('Dz', is_train=self.is_train,
                            norm='batch', activation='relu')

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm='batch', activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim)

        # Generate images (a->b)
        # NOTE(review): G and E are each invoked twice below; this presumably
        # relies on variable reuse inside those classes -- confirm there.
        image_ab = self.image_ab = G(image_a, z)
        z_reconstruct = E(image_ab)

        # Encode z (G(A, z) -> z)
        z_encoded = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)
        Dz_real = Dz(z)
        Dz_fake = Dz(z_encoded)

        # L1 reconstruction of B from the encoded-z path.
        loss_image_reconstruct = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))

        # Least-squares GAN losses; real target 0.9 is one-sided label
        # smoothing instead of 1.0.
        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                    tf.reduce_mean(tf.square(D_fake))) * 0.5

        loss_image_cycle = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
                            tf.reduce_mean(tf.square(D_fake_encoded))) * 0.5

        # L1 recovery of the sampled z from the generated image.
        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_reconstruct))

        # LSGAN loss for the latent discriminator (real = sampled z,
        # fake = encoded z), with the same 0.9 smoothing.
        loss_Dz = (tf.reduce_mean(tf.squared_difference(Dz_real, 0.9)) +
                   tf.reduce_mean(tf.square(Dz_fake))) * 0.5

        # Single min-max objective: D and Dz minimize it (drive their own
        # errors down); G and E minimize its negation below, which both
        # maximizes the discriminator errors and minimizes the two
        # reconstruction terms (note their negative signs here).
        loss = self._gamma * loss_Dz \
            + loss_image_cycle - self._lambda1 * loss_image_reconstruct \
            + loss_gan - self._lambda2 * loss_latent_cycle

        # Optimizer
        # Only the D optimizer advances global_step, so it counts iterations
        # rather than individual optimizer applications.
        self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss, var_list=D.var_list, global_step=self.global_step)
        self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(-loss, var_list=G.var_list)
        self.optimizer_Dz = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss, var_list=Dz.var_list)
        self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(-loss, var_list=E.var_list)

        # Summaries
        self.loss_image_reconstruct = loss_image_reconstruct
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_Dz = loss_Dz
        self.loss = loss

        tf.summary.scalar('loss/image_reconstruct', loss_image_reconstruct)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/Dz', loss_Dz)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/Dz_real', tf.reduce_mean(Dz_real))
        tf.summary.scalar('model/Dz_fake', tf.reduce_mean(Dz_fake))
        tf.summary.scalar('model/lr', self.lr)
        # Log only the first example of each batch as an image summary.
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()
|
|
|
+
|
|
|
+ def train(self, sess, summary_writer, data_A, data_B):
|
|
|
+ logger.info('Start training.')
|
|
|
+ logger.info(' {} images from A'.format(len(data_A)))
|
|
|
+ logger.info(' {} images from B'.format(len(data_B)))
|
|
|
+
|
|
|
+ data_size = min(len(data_A), len(data_B))
|
|
|
+ num_batch = data_size // self._batch_size
|
|
|
+ epoch_length = num_batch * self._batch_size
|
|
|
+
|
|
|
+ num_initial_iter = 8
|
|
|
+ num_decay_iter = 2
|
|
|
+ lr = lr_initial = 0.0002
|
|
|
+ lr_decay = lr_initial / num_decay_iter
|
|
|
+
|
|
|
+ initial_step = sess.run(self.global_step)
|
|
|
+ num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
|
|
|
+ t = trange(initial_step, num_global_step,
|
|
|
+ total=num_global_step, initial=initial_step)
|
|
|
+
|
|
|
+ for step in t:
|
|
|
+ #TODO: resume training with global_step
|
|
|
+ epoch = step // epoch_length
|
|
|
+ iter = step % epoch_length
|
|
|
+
|
|
|
+ if epoch > num_initial_iter:
|
|
|
+ lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)
|
|
|
+
|
|
|
+ if iter == 0:
|
|
|
+ data = zip(data_A, data_B)
|
|
|
+ random.shuffle(data)
|
|
|
+ data_A, data_B = zip(*data)
|
|
|
+
|
|
|
+ image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
|
+ image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
|
+ sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
|
|
|
+
|
|
|
+ fetches = [self.loss,
|
|
|
+ self.optimizer_D, self.optimizer_Dz,
|
|
|
+ self.optimizer_G, self.optimizer_E]
|
|
|
+ if step % self._log_step == 0:
|
|
|
+ fetches += [self.summary_op]
|
|
|
+
|
|
|
+ fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
|
|
|
+ self.image_b: image_b,
|
|
|
+ self.is_train: True,
|
|
|
+ self.lr: lr,
|
|
|
+ self.z: sample_z})
|
|
|
+
|
|
|
+ z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
|
|
|
+ image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
+ self.image_b: image_b,
|
|
|
+ self.lr: lr,
|
|
|
+ self.z: z,
|
|
|
+ self.is_train: True})
|
|
|
+ imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
|
|
|
+
|
|
|
+ if step % self._log_step == 0:
|
|
|
+ summary_writer.add_summary(fetched[-1], step)
|
|
|
+ summary_writer.flush()
|
|
|
+ t.set_description('Loss({:.3f})'.format(fetched[0]))
|
|
|
+
|
|
|
+ def test(self, sess, data_A, data_B, base_dir):
|
|
|
+ step = 0
|
|
|
+ for (dataA, dataB) in tqdm(zip(data_A, data_B)):
|
|
|
+ step += 1
|
|
|
+ image_a = np.expand_dims(dataA, axis=0)
|
|
|
+ image_b = np.expand_dims(dataB, axis=0)
|
|
|
+ images = []
|
|
|
+ images.append(image_a)
|
|
|
+ images.append(image_b)
|
|
|
+
|
|
|
+ for i in range(23):
|
|
|
+ z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
|
|
|
+ image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
+ self.z: z,
|
|
|
+ self.is_train: True})
|
|
|
+ images.append(image_ab)
|
|
|
+
|
|
|
+ image_rows = []
|
|
|
+ for i in range(5):
|
|
|
+ image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
|
|
|
+ images = np.concatenate(image_rows, axis=1)
|
|
|
+ images = np.squeeze(images, axis=0)
|
|
|
+ imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)
|
|
|
+
|
|
|
+ step=0
|
|
|
+ for (dataA, dataB) in tqdm(zip(data_A, data_B)):
|
|
|
+ step += 1
|
|
|
+ image_a = np.expand_dims(dataA, axis=0)
|
|
|
+ image_b = np.expand_dims(dataB, axis=0)
|
|
|
+ images = []
|
|
|
+ images.append(image_a)
|
|
|
+ images.append(image_b)
|
|
|
+
|
|
|
+ for i in range(23):
|
|
|
+ z = np.zeros((1, self._latent_dim))
|
|
|
+ z[0][0] = (i / 23.0 - 0.5) * 2.0
|
|
|
+ image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
+ self.z: z,
|
|
|
+ self.is_train: True})
|
|
|
+ images.append(image_ab)
|
|
|
+
|
|
|
+ image_rows = []
|
|
|
+ for i in range(5):
|
|
|
+ image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
|
|
|
+ images = np.concatenate(image_rows, axis=1)
|
|
|
+ images = np.squeeze(images, axis=0)
|
|
|
+ imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
|