|
@@ -21,6 +21,7 @@ class BicycleGAN(object):
|
|
self._coeff_reconstruct = args.coeff_reconstruct
|
|
self._coeff_reconstruct = args.coeff_reconstruct
|
|
self._coeff_latent = args.coeff_latent
|
|
self._coeff_latent = args.coeff_latent
|
|
self._coeff_kl = args.coeff_kl
|
|
self._coeff_kl = args.coeff_kl
|
|
|
|
+ self._norm = 'instance' if args.instance_normalization else 'batch'
|
|
|
|
|
|
self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
|
|
self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
|
|
self._image_shape = [self._image_size, self._image_size, 3]
|
|
self._image_shape = [self._image_size, self._image_size, 3]
|
|
@@ -53,16 +54,16 @@ class BicycleGAN(object):
|
|
|
|
|
|
# Generator
|
|
# Generator
|
|
G = Generator('G', is_train=self.is_train,
|
|
G = Generator('G', is_train=self.is_train,
|
|
- norm='batch', image_size=self._image_size)
|
|
|
|
|
|
+ norm=self._norm, image_size=self._image_size)
|
|
|
|
|
|
# Discriminator
|
|
# Discriminator
|
|
D = Discriminator('D', is_train=self.is_train,
|
|
D = Discriminator('D', is_train=self.is_train,
|
|
- norm='batch', activation='leaky',
|
|
|
|
|
|
+ norm=self._norm, activation='leaky',
|
|
image_size=self._image_size)
|
|
image_size=self._image_size)
|
|
|
|
|
|
# Encoder
|
|
# Encoder
|
|
E = Encoder('E', is_train=self.is_train,
|
|
E = Encoder('E', is_train=self.is_train,
|
|
- norm='batch', activation='relu',
|
|
|
|
|
|
+ norm=self._norm, activation='relu',
|
|
image_size=self._image_size, latent_dim=self._latent_dim)
|
|
image_size=self._image_size, latent_dim=self._latent_dim)
|
|
|
|
|
|
# conditional VAE-GAN: B -> z -> B'
|
|
# conditional VAE-GAN: B -> z -> B'
|
|
@@ -97,12 +98,14 @@ class BicycleGAN(object):
|
|
self._coeff_kl * loss_kl
|
|
self._coeff_kl * loss_kl
|
|
|
|
|
|
# Optimizer
|
|
# Optimizer
|
|
- self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
|
- .minimize(loss, var_list=D.var_list, global_step=self.global_step)
|
|
|
|
- self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
|
- .minimize(-loss, var_list=G.var_list)
|
|
|
|
- self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
|
- .minimize(-loss, var_list=E.var_list)
|
|
|
|
|
|
+ update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
|
|
|
+ with tf.control_dependencies(update_ops):
|
|
|
|
+ self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
|
+ .minimize(loss, var_list=D.var_list, global_step=self.global_step)
|
|
|
|
+ self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
|
+ .minimize(-loss, var_list=G.var_list)
|
|
|
|
+ self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
|
+ .minimize(-loss, var_list=E.var_list)
|
|
|
|
|
|
# Summaries
|
|
# Summaries
|
|
self.loss_vae_gan = loss_vae_gan
|
|
self.loss_vae_gan = loss_vae_gan
|
|
@@ -164,7 +167,6 @@ class BicycleGAN(object):
|
|
|
|
|
|
image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
- #sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
|
|
|
|
sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
|
|
sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
|
|
|
|
|
|
fetches = [self.loss, self.optimizer_D,
|
|
fetches = [self.loss, self.optimizer_D,
|
|
@@ -203,7 +205,7 @@ class BicycleGAN(object):
|
|
images_linear.append(image_b)
|
|
images_linear.append(image_b)
|
|
|
|
|
|
for i in range(23):
|
|
for i in range(23):
|
|
- z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
|
|
|
|
|
|
+ z = np.random.normal(size=(1, self._latent_dim))
|
|
image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
self.z: z,
|
|
self.z: z,
|
|
self.is_train: False})
|
|
self.is_train: False})
|