|
@@ -27,7 +27,7 @@ class BicycleGAN(object):
|
|
|
|
|
|
self.is_train = tf.placeholder(tf.bool, name='is_train')
|
|
|
self.lr = tf.placeholder(tf.float32, name='lr')
|
|
|
- self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
|
|
|
+ self.global_step = tf.train.get_or_create_global_step(graph=None)
|
|
|
|
|
|
image_a = self.image_a = \
|
|
|
tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
|
|
@@ -90,10 +90,10 @@ class BicycleGAN(object):
|
|
|
loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))
|
|
|
|
|
|
loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
|
|
|
- tf.exp(2 * z_encoded_log_sigma), 1)
|
|
|
+ tf.exp(2 * z_encoded_log_sigma))
|
|
|
|
|
|
- loss = loss_vae_gan + self._coeff_reconstruct * loss_image_cycle + \
|
|
|
- loss_gan + self._coeff_latent * loss_latent_cycle + \
|
|
|
+ loss = loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
|
|
|
+ loss_gan - self._coeff_latent * loss_latent_cycle - \
|
|
|
self._coeff_kl * loss_kl
|
|
|
|
|
|
# Optimizer
|
|
@@ -105,18 +105,18 @@ class BicycleGAN(object):
|
|
|
.minimize(-loss, var_list=E.var_list)
|
|
|
|
|
|
# Summaries
|
|
|
- self.loss_image_reconstruct = loss_image_reconstruct
|
|
|
+ self.loss_vae_gan = loss_vae_gan
|
|
|
self.loss_image_cycle = loss_image_cycle
|
|
|
self.loss_latent_cycle = loss_latent_cycle
|
|
|
self.loss_gan = loss_gan
|
|
|
- self.loss_z_kl = loss_z_kl
|
|
|
+ self.loss_kl = loss_kl
|
|
|
self.loss = loss
|
|
|
|
|
|
- tf.summary.scalar('loss/image_reconstruct', loss_image_reconstruct)
|
|
|
+ tf.summary.scalar('loss/vae_gan', loss_vae_gan)
|
|
|
tf.summary.scalar('loss/image_cycle', loss_image_cycle)
|
|
|
tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
|
|
|
tf.summary.scalar('loss/gan', loss_gan)
|
|
|
- tf.summary.scalar('loss/Dz', loss_z_kl)
|
|
|
+ tf.summary.scalar('loss/kl', loss_kl)
|
|
|
tf.summary.scalar('loss/total', loss)
|
|
|
tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
|
|
|
tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
|
|
@@ -167,8 +167,7 @@ class BicycleGAN(object):
|
|
|
#sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
|
|
|
sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
|
|
|
|
|
|
- fetches = [self.loss,
|
|
|
- self.optimizer_D, self.optimizer_Dz,
|
|
|
+ fetches = [self.loss, self.optimizer_D,
|
|
|
self.optimizer_G, self.optimizer_E]
|
|
|
if step % self._log_step == 0:
|
|
|
fetches += [self.summary_op]
|