Browse Source

update bicycleGAN following NIPS2017 paper

Youngwoon Lee 7 years ago
parent
commit
5597a8e691
6 changed files with 69 additions and 107 deletions
  1. 7 7
      bicycle-gan.py
  2. 0 26
      discriminator_z.py
  3. 6 2
      encoder.py
  4. 5 6
      generator.py
  5. 50 64
      model.py
  6. 1 2
      ops.py

+ 7 - 7
bicycle-gan.py

@@ -19,19 +19,19 @@ parser.add_argument('--train', default=True, type=str2bool,
                     help="Training mode")
 parser.add_argument('--task', type=str, default='edges2shoes',
                     help='Task name')
-parser.add_argument('--gamma', type=float, default=1,
-                    help='Loss coefficient')
-parser.add_argument('--lambda1', type=float, default=1,
-                    help='Loss coefficient')
-parser.add_argument('--lambda2', type=float, default=1,
-                    help='Loss coefficient')
+parser.add_argument('--coeff_kl', type=float, default=0.01,
+                    help='Loss coefficient for KL divergence')
+parser.add_argument('--coeff_reconstruct', type=float, default=10,
+                    help='Loss coefficient for reconstruct')
+parser.add_argument('--coeff_latent', type=float, default=0.5,
+                    help='Loss coefficient for latent cycle')
 parser.add_argument('--instance_normalization', default=False, type=bool,
                     help="Use instance norm instead of batch norm")
 parser.add_argument('--log_step', default=100, type=int,
                     help="Tensorboard log frequency")
 parser.add_argument('--batch_size', default=1, type=int,
                     help="Batch size")
-parser.add_argument('--image_size', default=128, type=int,
+parser.add_argument('--image_size', default=256, type=int,
                     help="Image size")
 parser.add_argument('--latent_dim', default=8, type=int,
                     help="Dimensionality of latent vector")

+ 0 - 26
discriminator_z.py

@@ -1,26 +0,0 @@
-import tensorflow as tf
-from utils import logger
-import ops
-
-
-class DiscriminatorZ(object):
-    def __init__(self, name, is_train, norm='batch', activation='relu'):
-        logger.info('Init DiscriminatorZ %s', name)
-        self.name = name
-        self._is_train = is_train
-        self._norm = norm
-        self._activation = activation
-        self._reuse = False
-
-    def __call__(self, input):
-        with tf.variable_scope(self.name, reuse=self._reuse):
-            D = input
-            for i in range(3):
-                D = ops.mlp(D, 512, 'FC512_{}'.format(i), self._is_train,
-                            self._reuse, self._norm, self._activation)
-            D = ops.mlp(D, 1, 'FC1_{}'.format(i), self._is_train,
-                        self._reuse, norm=None, activation=None)
-
-            self._reuse = True
-            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
-            return D

+ 6 - 2
encoder.py

@@ -26,9 +26,13 @@ class Encoder(object):
                 E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                 self._reuse, norm=self._norm if i else None, activation='leaky')
             E = tf.reshape(E, [-1, 512])
-            E = ops.mlp(E, self._latent_dim, 'FC8', self._is_train, self._reuse,
+            mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
                         norm=None, activation=None)
+            log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
+                        norm=None, activation=None)
+
+            z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
 
             self._reuse = True
             self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
-            return E
+            return z, mu, log_sigma

+ 5 - 6
generator.py

@@ -14,7 +14,6 @@ class Generator(object):
 
     def __call__(self, input, z):
         with tf.variable_scope(self.name, reuse=self._reuse):
-            self._dropout = tf.constant(1.0)
             batch_size = int(input.get_shape()[0])
             latent_dim = int(z.get_shape()[-1])
             num_filters = [64, 128, 256, 512, 512, 512, 512]
@@ -23,24 +22,24 @@ class Generator(object):
 
             layers = []
             G = input
+            z = tf.reshape(z, [batch_size, 1, 1, latent_dim])
+            z = tf.tile(z, [1, self._image_size, self._image_size, 1])
+            G = tf.concat([G, z], axis=3)
             for i, n in enumerate(num_filters):
                 G = ops.conv_block(G, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                 self._reuse, norm=self._norm if i else None, activation='leaky')
                 layers.append(G)
 
-            z = tf.reshape(z, [batch_size, 1, 1, latent_dim])
-            G = tf.concat([G, z], axis=3)
-
             layers.pop()
             num_filters.pop()
             num_filters.reverse()
 
             for i, n in enumerate(num_filters):
                 G = ops.deconv_block(G, n, 'CD{}_{}'.format(n, i), 4, 2, self._is_train,
-                                self._reuse, norm=self._norm, activation='relu', dropout=self._dropout)
+                                self._reuse, norm=self._norm, activation='relu')
                 G = tf.concat([G, layers.pop()], axis=3)
             G = ops.deconv_block(G, 3, 'last_layer', 4, 2, self._is_train,
-                               self._reuse, norm=None, activation='tanh', dropout=self._dropout)
+                               self._reuse, norm=None, activation='tanh')
 
             self._reuse = True
             self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)

+ 50 - 64
model.py

@@ -9,7 +9,6 @@ import numpy as np
 from generator import Generator
 from encoder import Encoder
 from discriminator import Discriminator
-from discriminator_z import DiscriminatorZ
 from utils import logger
 
 
@@ -19,9 +18,9 @@ class BicycleGAN(object):
         self._batch_size = args.batch_size
         self._image_size = args.image_size
         self._latent_dim = args.latent_dim
-        self._lambda1 = args.lambda1
-        self._lambda2 = args.lambda2
-        self._gamma = args.gamma
+        self._coeff_reconstruct = args.coeff_reconstruct
+        self._coeff_latent = args.coeff_latent
+        self._coeff_kl = args.coeff_kl
 
         self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
         self._image_shape = [self._image_size, self._image_size, 3]
@@ -60,53 +59,48 @@ class BicycleGAN(object):
         D = Discriminator('D', is_train=self.is_train,
                           norm='batch', activation='leaky',
                           image_size=self._image_size)
-        Dz = DiscriminatorZ('Dz', is_train=self.is_train,
-                             norm='batch', activation='relu')
 
         # Encoder
         E = Encoder('E', is_train=self.is_train,
                     norm='batch', activation='relu',
                     image_size=self._image_size, latent_dim=self._latent_dim)
 
-        # Generate images (a->b)
+        # conditional VAE-GAN: B -> z -> B'
+        z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
+        image_ab_encoded = G(image_a, z_encoded)
+
+        # conditional Latent Regressor-GAN: z -> B' -> z'
         image_ab = self.image_ab = G(image_a, z)
-        z_reconstruct = E(image_ab)
+        z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)
 
-        # Encode z (G(A, z) -> z)
-        z_encoded = E(image_b)
-        image_ab_encoded = G(image_a, z_encoded)
 
         # Discriminate real/fake images
         D_real = D(image_b)
         D_fake = D(image_ab)
         D_fake_encoded = D(image_ab_encoded)
-        Dz_real = Dz(z)
-        Dz_fake = Dz(z_encoded)
 
-        loss_image_reconstruct = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
+        loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
+            tf.reduce_mean(tf.square(D_fake_encoded)))
 
-        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
-            tf.reduce_mean(tf.square(D_fake))) * 0.5
+        loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
 
-        loss_image_cycle = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
-            tf.reduce_mean(tf.square(D_fake_encoded))) * 0.5
+        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
+            tf.reduce_mean(tf.square(D_fake)))
 
-        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_reconstruct))
+        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))
 
-        loss_Dz = (tf.reduce_mean(tf.squared_difference(Dz_real, 0.9)) +
-            tf.reduce_mean(tf.square(Dz_fake))) * 0.5
+        loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
+                                       tf.exp(2 * z_encoded_log_sigma), 1)
 
-        loss = self._gamma * loss_Dz \
-            + loss_image_cycle - self._lambda1 * loss_image_reconstruct \
-            + loss_gan - self._lambda2 * loss_latent_cycle
+        loss = loss_vae_gan + self._coeff_reconstruct * loss_image_cycle + \
+            loss_gan + self._coeff_latent * loss_latent_cycle + \
+            self._coeff_kl * loss_kl
 
         # Optimizer
         self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                             .minimize(loss, var_list=D.var_list, global_step=self.global_step)
         self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                             .minimize(-loss, var_list=G.var_list)
-        self.optimizer_Dz = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
-                            .minimize(loss, var_list=Dz.var_list)
         self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                             .minimize(-loss, var_list=E.var_list)
 
@@ -115,20 +109,18 @@ class BicycleGAN(object):
         self.loss_image_cycle = loss_image_cycle
         self.loss_latent_cycle = loss_latent_cycle
         self.loss_gan = loss_gan
-        self.loss_Dz = loss_Dz
+        self.loss_z_kl = loss_z_kl
         self.loss = loss
 
         tf.summary.scalar('loss/image_reconstruct', loss_image_reconstruct)
         tf.summary.scalar('loss/image_cycle', loss_image_cycle)
         tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
         tf.summary.scalar('loss/gan', loss_gan)
-        tf.summary.scalar('loss/Dz', loss_Dz)
+        tf.summary.scalar('loss/Dz', loss_z_kl)
         tf.summary.scalar('loss/total', loss)
         tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
         tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
         tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
-        tf.summary.scalar('model/Dz_real', tf.reduce_mean(Dz_real))
-        tf.summary.scalar('model/Dz_fake', tf.reduce_mean(Dz_fake))
         tf.summary.scalar('model/lr', self.lr)
         tf.summary.image('image/A', image_a[0:1])
         tf.summary.image('image/B', image_b[0:1])
@@ -141,7 +133,9 @@ class BicycleGAN(object):
         logger.info('  {} images from A'.format(len(data_A)))
         logger.info('  {} images from B'.format(len(data_B)))
 
-        data_size = min(len(data_A), len(data_B))
+        assert len(data_A) == len(data_B), \
+            'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))
+        data_size = len(data_A)
         num_batch = data_size // self._batch_size
         epoch_length = num_batch * self._batch_size
 
@@ -170,7 +164,8 @@ class BicycleGAN(object):
 
             image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
             image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
-            sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
+            #sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
+            sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
 
             fetches = [self.loss,
                        self.optimizer_D, self.optimizer_Dz,
@@ -184,15 +179,13 @@ class BicycleGAN(object):
                                                    self.lr: lr,
                                                    self.z: sample_z})
 
-            z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
-            image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
-                                                   self.image_b: image_b,
-                                                   self.lr: lr,
-                                                    self.z: z,
-                                                    self.is_train: True})
-            imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
-
             if step % self._log_step == 0:
+                z = np.random.normal(size=(1, self._latent_dim))
+                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
+                                                            self.z: z,
+                                                            self.is_train: False})
+                imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
+
                 summary_writer.add_summary(fetched[-1], step)
                 summary_writer.flush()
                 t.set_description('Loss({:.3f})'.format(fetched[0]))
@@ -203,44 +196,37 @@ class BicycleGAN(object):
             step += 1
             image_a = np.expand_dims(dataA, axis=0)
             image_b = np.expand_dims(dataB, axis=0)
-            images = []
-            images.append(image_a)
-            images.append(image_b)
+            images_random = []
+            images_random.append(image_a)
+            images_random.append(image_b)
+            images_linear = []
+            images_linear.append(image_a)
+            images_linear.append(image_b)
 
             for i in range(23):
                 z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
                 image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                         self.z: z,
-                                                        self.is_train: True})
-                images.append(image_ab)
+                                                        self.is_train: False})
+                images_random.append(image_ab)
+
+                z = np.zeros((1, self._latent_dim))
+                z[0][0] = (i / 23.0 - 0.5) * 2.0
+                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
+                                                        self.z: z,
+                                                        self.is_train: False})
+                images_linear.append(image_ab)
 
             image_rows = []
             for i in range(5):
-                image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
+                image_rows.append(np.concatenate(images_random[i*5:(i+1)*5], axis=2))
             images = np.concatenate(image_rows, axis=1)
             images = np.squeeze(images, axis=0)
             imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)
 
-        step=0
-        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
-            step += 1
-            image_a = np.expand_dims(dataA, axis=0)
-            image_b = np.expand_dims(dataB, axis=0)
-            images = []
-            images.append(image_a)
-            images.append(image_b)
-
-            for i in range(23):
-                z = np.zeros((1, self._latent_dim))
-                z[0][0] = (i / 23.0 - 0.5) * 2.0
-                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
-                                                        self.z: z,
-                                                        self.is_train: True})
-                images.append(image_ab)
-
             image_rows = []
             for i in range(5):
-                image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
+                image_rows.append(np.concatenate(images_linear[i*5:(i+1)*5], axis=2))
             images = np.concatenate(image_rows, axis=1)
             images = np.squeeze(images, axis=0)
             imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)

+ 1 - 2
ops.py

@@ -106,10 +106,9 @@ def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT'):
         return tf.nn.relu(input + out)
 
 def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
-                 norm, activation, dropout):
+                 norm, activation):
     with tf.variable_scope(name, reuse=reuse):
         out = conv2d_transpose(input, num_filters, k_size, stride, reuse)
         out = _norm(out, is_train, reuse, norm)
-        out = tf.nn.dropout(out, dropout)
         out = _activation(out, activation)
         return out