
fix batchnorm for testing

Youngwoon Lee, 7 years ago
parent commit 0487dc0ace
2 changed files with 14 additions and 13 deletions
  1. model.py (+13 -11)
  2. ops.py (+1 -2)

model.py (+13 -11)

@@ -21,6 +21,7 @@ class BicycleGAN(object):
         self._coeff_reconstruct = args.coeff_reconstruct
         self._coeff_latent = args.coeff_latent
         self._coeff_kl = args.coeff_kl
+        self._norm = 'instance' if args.instance_normalization else 'batch'
 
         self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
         self._image_shape = [self._image_size, self._image_size, 3]
@@ -53,16 +54,16 @@ class BicycleGAN(object):
 
         # Generator
         G = Generator('G', is_train=self.is_train,
-                      norm='batch', image_size=self._image_size)
+                      norm=self._norm, image_size=self._image_size)
 
         # Discriminator
         D = Discriminator('D', is_train=self.is_train,
-                          norm='batch', activation='leaky',
+                          norm=self._norm, activation='leaky',
                           image_size=self._image_size)
 
         # Encoder
         E = Encoder('E', is_train=self.is_train,
-                    norm='batch', activation='relu',
+                    norm=self._norm, activation='relu',
                     image_size=self._image_size, latent_dim=self._latent_dim)
 
         # conditional VAE-GAN: B -> z -> B'
@@ -97,12 +98,14 @@ class BicycleGAN(object):
             self._coeff_kl * loss_kl
 
         # Optimizer
-        self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
-                            .minimize(loss, var_list=D.var_list, global_step=self.global_step)
-        self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
-                            .minimize(-loss, var_list=G.var_list)
-        self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
-                            .minimize(-loss, var_list=E.var_list)
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        with tf.control_dependencies(update_ops):
+            self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                                .minimize(loss, var_list=D.var_list, global_step=self.global_step)
+            self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                                .minimize(-loss, var_list=G.var_list)
+            self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                                .minimize(-loss, var_list=E.var_list)
 
         # Summaries
         self.loss_vae_gan = loss_vae_gan
@@ -164,7 +167,6 @@ class BicycleGAN(object):
 
             image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
             image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
-            #sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
             sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
 
             fetches = [self.loss, self.optimizer_D,
@@ -203,7 +205,7 @@ class BicycleGAN(object):
             images_linear.append(image_b)
 
             for i in range(23):
-                z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
+                z = np.random.normal(size=(1, self._latent_dim))
                 image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                         self.z: z,
                                                         self.is_train: False})
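
Why the optimizer change above matters: tf.contrib.layers.batch_norm appends its moving-average update ops to the tf.GraphKeys.UPDATE_OPS collection rather than running them inline, so each train op needs an explicit control dependency on that collection, or the moving mean and variance never move off their initial values. The z-sampling switch from uniform to normal matches the unit-Gaussian prior that the KL term (self._coeff_kl * loss_kl) pushes the encoder toward, and the new self._norm field threads the instance-vs-batch choice through G, D, and E. Below is a minimal, self-contained sketch of the UPDATE_OPS pattern, assuming the TF 1.x API this repo uses; the one-layer toy net is a stand-in, not the repo's model.

    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 8])
    is_train = tf.placeholder(tf.bool, [])

    # With the default updates_collections, batch_norm parks its moving-average
    # update ops in tf.GraphKeys.UPDATE_OPS; they do NOT run by themselves.
    h = tf.contrib.layers.batch_norm(tf.layers.dense(x, 16),
                                     decay=0.99, center=True, scale=True,
                                     is_training=is_train)
    loss = tf.reduce_mean(tf.square(h))

    # Without this control dependency the moving mean/variance never update --
    # the bug this commit fixes for the real G/D/E optimizers.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5) \
                       .minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op, feed_dict={x: np.random.randn(4, 8),
                                      is_train: True})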

ops.py (+1 -2)

@@ -20,8 +20,7 @@ def _norm(input, is_train, reuse=True, norm=None):
         with tf.variable_scope('batch_norm', reuse=reuse):
             out = tf.contrib.layers.batch_norm(input,
                                                decay=0.99, center=True,
-                                               scale=True, is_training=is_train,
-                                               updates_collections=None)
+                                               scale=True, is_training=True)
     else:
         out = input
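
Why ops.py changed: dropping updates_collections=None stops forcing the moving-average updates to run in place (they now go through the UPDATE_OPS control dependency added in model.py), and hard-coding is_training=True makes batch norm always normalize with the current batch's statistics, even when the model runs with is_train=False. Normalizing with batch statistics at test time is the convention in image-to-image GANs such as pix2pix and CycleGAN, which is presumably what "fix batchnorm for testing" refers to. A short sketch of the behavioral difference, again assuming TF 1.x:

    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 4])

    # is_training=True: normalize with the current batch's mean/variance,
    # as the commit now does unconditionally.
    out_batch_stats = tf.contrib.layers.batch_norm(
        x, decay=0.99, center=True, scale=True,
        is_training=True, scope='bn')

    # For contrast, is_training=False normalizes with the stored moving
    # averages -- the test-time mode this commit deliberately avoids.
    out_moving_stats = tf.contrib.layers.batch_norm(
        x, decay=0.99, center=True, scale=True,
        is_training=False, scope='bn', reuse=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.randn(8, 4).astype(np.float32)
        a, b = sess.run([out_batch_stats, out_moving_stats],
                        feed_dict={x: batch})
        # a is zero-mean/unit-variance per feature; b was shifted/scaled by
        # the moving averages, which here still sit at their initial values.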