Pārlūkot izejas kodu

add coeff for gan loss

youngwoon 6 gadi atpakaļ
vecāks
revīzija
7feb637ba9
2 mainītis faili ar 7 papildinājumiem un 4 dzēšanām
  1. 5 3
      bicycle-gan.py
  2. 2 1
      model.py

+ 5 - 3
bicycle-gan.py

@@ -19,14 +19,16 @@ parser.add_argument('--train', default=True, type=str2bool,
                     help="Training mode")
 parser.add_argument('--task', type=str, default='edges2shoes',
                     help='Task name')
+parser.add_argument('--coeff_gan', type=float, default=1.0,
+                    help='Loss coefficient for GAN loss')
 parser.add_argument('--coeff_vae', type=float, default=1.0,
-                    help='Loss coefficient for VAE')
+                    help='Loss coefficient for VAE loss')
 parser.add_argument('--coeff_kl', type=float, default=0.01,
                     help='Loss coefficient for KL divergence')
 parser.add_argument('--coeff_reconstruct', type=float, default=10,
-                    help='Loss coefficient for reconstruct')
+                    help='Loss coefficient for reconstruction error')
 parser.add_argument('--coeff_latent', type=float, default=0.5,
-                    help='Loss coefficient for latent cycle')
+                    help='Loss coefficient for latent cycle loss')
 parser.add_argument('--instance_normalization', default=False, type=bool,
                     help="Use instance norm instead of batch norm")
 parser.add_argument('--log_step', default=100, type=int,

+ 2 - 1
model.py

@@ -18,6 +18,7 @@ class BicycleGAN(object):
         self._batch_size = args.batch_size
         self._image_size = args.image_size
         self._latent_dim = args.latent_dim
+        self._coeff_gan = args.coeff_gan
         self._coeff_vae = args.coeff_vae
         self._coeff_reconstruct = args.coeff_reconstruct
         self._coeff_latent = args.coeff_latent
@@ -97,7 +98,7 @@ class BicycleGAN(object):
                                        tf.exp(2 * z_encoded_log_sigma))
 
         loss = self._coeff_vae * loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
-            loss_gan - self._coeff_latent * loss_latent_cycle - \
+            self._coeff_gan * loss_gan - self._coeff_latent * loss_latent_cycle - \
             self._coeff_kl * loss_kl
 
         # Optimizer