Explorar o código

add coeff for vae loss

youngwoon %!s(int64=5) %!d(string=hai) anos
pai
achega
afd145be96
Modificáronse 2 ficheiros con 4 adicións e 1 borrados
  1. 2 0
      bicycle-gan.py
  2. 2 1
      model.py

+ 2 - 0
bicycle-gan.py

@@ -19,6 +19,8 @@ parser.add_argument('--train', default=True, type=str2bool,
                     help="Training mode")
 parser.add_argument('--task', type=str, default='edges2shoes',
                     help='Task name')
+parser.add_argument('--coeff_vae', type=float, default=1.0,
+                    help='Loss coefficient for VAE')
 parser.add_argument('--coeff_kl', type=float, default=0.01,
                     help='Loss coefficient for KL divergence')
 parser.add_argument('--coeff_reconstruct', type=float, default=10,

+ 2 - 1
model.py

@@ -18,6 +18,7 @@ class BicycleGAN(object):
         self._batch_size = args.batch_size
         self._image_size = args.image_size
         self._latent_dim = args.latent_dim
+        self._coeff_vae = args.coeff_vae
         self._coeff_reconstruct = args.coeff_reconstruct
         self._coeff_latent = args.coeff_latent
         self._coeff_kl = args.coeff_kl
@@ -95,7 +96,7 @@ class BicycleGAN(object):
         loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
                                        tf.exp(2 * z_encoded_log_sigma))
 
-        loss = loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
+        loss = self._coeff_vae * loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
             loss_gan - self._coeff_latent * loss_latent_cycle - \
             self._coeff_kl * loss_kl