|
@@ -9,7 +9,6 @@ import numpy as np
|
|
|
from generator import Generator
|
|
|
from encoder import Encoder
|
|
|
from discriminator import Discriminator
|
|
|
-from discriminator_z import DiscriminatorZ
|
|
|
from utils import logger
|
|
|
|
|
|
|
|
@@ -19,9 +18,9 @@ class BicycleGAN(object):
|
|
|
self._batch_size = args.batch_size
|
|
|
self._image_size = args.image_size
|
|
|
self._latent_dim = args.latent_dim
|
|
|
- self._lambda1 = args.lambda1
|
|
|
- self._lambda2 = args.lambda2
|
|
|
- self._gamma = args.gamma
|
|
|
+ self._coeff_reconstruct = args.coeff_reconstruct
|
|
|
+ self._coeff_latent = args.coeff_latent
|
|
|
+ self._coeff_kl = args.coeff_kl
|
|
|
|
|
|
self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
|
|
|
self._image_shape = [self._image_size, self._image_size, 3]
|
|
@@ -60,53 +59,48 @@ class BicycleGAN(object):
|
|
|
D = Discriminator('D', is_train=self.is_train,
|
|
|
norm='batch', activation='leaky',
|
|
|
image_size=self._image_size)
|
|
|
- Dz = DiscriminatorZ('Dz', is_train=self.is_train,
|
|
|
- norm='batch', activation='relu')
|
|
|
|
|
|
# Encoder
|
|
|
E = Encoder('E', is_train=self.is_train,
|
|
|
norm='batch', activation='relu',
|
|
|
image_size=self._image_size, latent_dim=self._latent_dim)
|
|
|
|
|
|
- # Generate images (a->b)
|
|
|
+ # conditional VAE-GAN: B -> z -> B'
|
|
|
+ z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
|
|
|
+ image_ab_encoded = G(image_a, z_encoded)
|
|
|
+
|
|
|
+ # conditional Latent Regressor-GAN: z -> B' -> z'
|
|
|
image_ab = self.image_ab = G(image_a, z)
|
|
|
- z_reconstruct = E(image_ab)
|
|
|
+ z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)
|
|
|
|
|
|
- # Encode z (G(A, z) -> z)
|
|
|
- z_encoded = E(image_b)
|
|
|
- image_ab_encoded = G(image_a, z_encoded)
|
|
|
|
|
|
# Discriminate real/fake images
|
|
|
D_real = D(image_b)
|
|
|
D_fake = D(image_ab)
|
|
|
D_fake_encoded = D(image_ab_encoded)
|
|
|
- Dz_real = Dz(z)
|
|
|
- Dz_fake = Dz(z_encoded)
|
|
|
|
|
|
- loss_image_reconstruct = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
|
|
|
+ loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
|
|
|
+ tf.reduce_mean(tf.square(D_fake_encoded)))
|
|
|
|
|
|
- loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
|
|
|
- tf.reduce_mean(tf.square(D_fake))) * 0.5
|
|
|
+ loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
|
|
|
|
|
|
- loss_image_cycle = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
|
|
|
- tf.reduce_mean(tf.square(D_fake_encoded))) * 0.5
|
|
|
+ loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
|
|
|
+ tf.reduce_mean(tf.square(D_fake)))
|
|
|
|
|
|
- loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_reconstruct))
|
|
|
+ loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))
|
|
|
|
|
|
- loss_Dz = (tf.reduce_mean(tf.squared_difference(Dz_real, 0.9)) +
|
|
|
- tf.reduce_mean(tf.square(Dz_fake))) * 0.5
|
|
|
+ loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
|
|
|
+ tf.exp(2 * z_encoded_log_sigma))
|
|
|
|
|
|
- loss = self._gamma * loss_Dz \
|
|
|
- + loss_image_cycle - self._lambda1 * loss_image_reconstruct \
|
|
|
- + loss_gan - self._lambda2 * loss_latent_cycle
|
|
|
+ loss = loss_vae_gan + self._coeff_reconstruct * loss_image_cycle + \
|
|
|
+ loss_gan + self._coeff_latent * loss_latent_cycle + \
|
|
|
+ self._coeff_kl * loss_kl
|
|
|
|
|
|
# Optimizer
|
|
|
self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
.minimize(loss, var_list=D.var_list, global_step=self.global_step)
|
|
|
self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
.minimize(-loss, var_list=G.var_list)
|
|
|
- self.optimizer_Dz = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
- .minimize(loss, var_list=Dz.var_list)
|
|
|
self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
|
|
|
.minimize(-loss, var_list=E.var_list)
|
|
|
|
|
@@ -115,20 +109,18 @@ class BicycleGAN(object):
|
|
|
self.loss_image_cycle = loss_image_cycle
|
|
|
self.loss_latent_cycle = loss_latent_cycle
|
|
|
self.loss_gan = loss_gan
|
|
|
- self.loss_Dz = loss_Dz
|
|
|
+ self.loss_z_kl = loss_kl
|
|
|
self.loss = loss
|
|
|
|
|
|
tf.summary.scalar('loss/image_reconstruct', loss_image_reconstruct)
|
|
|
tf.summary.scalar('loss/image_cycle', loss_image_cycle)
|
|
|
tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
|
|
|
tf.summary.scalar('loss/gan', loss_gan)
|
|
|
- tf.summary.scalar('loss/Dz', loss_Dz)
|
|
|
+ tf.summary.scalar('loss/kl', loss_kl)
|
|
|
tf.summary.scalar('loss/total', loss)
|
|
|
tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
|
|
|
tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
|
|
|
tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
|
|
|
- tf.summary.scalar('model/Dz_real', tf.reduce_mean(Dz_real))
|
|
|
- tf.summary.scalar('model/Dz_fake', tf.reduce_mean(Dz_fake))
|
|
|
tf.summary.scalar('model/lr', self.lr)
|
|
|
tf.summary.image('image/A', image_a[0:1])
|
|
|
tf.summary.image('image/B', image_b[0:1])
|
|
@@ -141,7 +133,9 @@ class BicycleGAN(object):
|
|
|
logger.info(' {} images from A'.format(len(data_A)))
|
|
|
logger.info(' {} images from B'.format(len(data_B)))
|
|
|
|
|
|
- data_size = min(len(data_A), len(data_B))
|
|
|
+ assert len(data_A) == len(data_B), \
|
|
|
+ 'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))
|
|
|
+ data_size = len(data_A)
|
|
|
num_batch = data_size // self._batch_size
|
|
|
epoch_length = num_batch * self._batch_size
|
|
|
|
|
@@ -170,7 +164,8 @@ class BicycleGAN(object):
|
|
|
|
|
|
image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
|
image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
|
|
|
- sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
|
|
|
+ #sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
|
|
|
+ sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))
|
|
|
|
|
|
fetches = [self.loss,
|
|
|
self.optimizer_D, self.optimizer_Dz,
|
|
@@ -184,15 +179,13 @@ class BicycleGAN(object):
|
|
|
self.lr: lr,
|
|
|
self.z: sample_z})
|
|
|
|
|
|
- z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
|
|
|
- image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
- self.image_b: image_b,
|
|
|
- self.lr: lr,
|
|
|
- self.z: z,
|
|
|
- self.is_train: True})
|
|
|
- imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
|
|
|
-
|
|
|
if step % self._log_step == 0:
|
|
|
+ z = np.random.normal(size=(1, self._latent_dim))
|
|
|
+ image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
+ self.z: z,
|
|
|
+ self.is_train: False})
|
|
|
+ imsave('results/r_{}.jpg'.format(step), np.squeeze(image_ab, axis=0))
|
|
|
+
|
|
|
summary_writer.add_summary(fetched[-1], step)
|
|
|
summary_writer.flush()
|
|
|
t.set_description('Loss({:.3f})'.format(fetched[0]))
|
|
@@ -203,44 +196,37 @@ class BicycleGAN(object):
|
|
|
step += 1
|
|
|
image_a = np.expand_dims(dataA, axis=0)
|
|
|
image_b = np.expand_dims(dataB, axis=0)
|
|
|
- images = []
|
|
|
- images.append(image_a)
|
|
|
- images.append(image_b)
|
|
|
+ images_random = []
|
|
|
+ images_random.append(image_a)
|
|
|
+ images_random.append(image_b)
|
|
|
+ images_linear = []
|
|
|
+ images_linear.append(image_a)
|
|
|
+ images_linear.append(image_b)
|
|
|
|
|
|
for i in range(23):
|
|
|
z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
|
|
|
image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
self.z: z,
|
|
|
- self.is_train: True})
|
|
|
- images.append(image_ab)
|
|
|
+ self.is_train: False})
|
|
|
+ images_random.append(image_ab)
|
|
|
+
|
|
|
+ z = np.zeros((1, self._latent_dim))
|
|
|
+ z[0][0] = (i / 23.0 - 0.5) * 2.0
|
|
|
+ image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
+ self.z: z,
|
|
|
+ self.is_train: False})
|
|
|
+ images_linear.append(image_ab)
|
|
|
|
|
|
image_rows = []
|
|
|
for i in range(5):
|
|
|
- image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
|
|
|
+ image_rows.append(np.concatenate(images_random[i*5:(i+1)*5], axis=2))
|
|
|
images = np.concatenate(image_rows, axis=1)
|
|
|
images = np.squeeze(images, axis=0)
|
|
|
imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)
|
|
|
|
|
|
- step=0
|
|
|
- for (dataA, dataB) in tqdm(zip(data_A, data_B)):
|
|
|
- step += 1
|
|
|
- image_a = np.expand_dims(dataA, axis=0)
|
|
|
- image_b = np.expand_dims(dataB, axis=0)
|
|
|
- images = []
|
|
|
- images.append(image_a)
|
|
|
- images.append(image_b)
|
|
|
-
|
|
|
- for i in range(23):
|
|
|
- z = np.zeros((1, self._latent_dim))
|
|
|
- z[0][0] = (i / 23.0 - 0.5) * 2.0
|
|
|
- image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
|
|
|
- self.z: z,
|
|
|
- self.is_train: True})
|
|
|
- images.append(image_ab)
|
|
|
-
|
|
|
image_rows = []
|
|
|
for i in range(5):
|
|
|
- image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
|
|
|
+ image_rows.append(np.concatenate(images_linear[i*5:(i+1)*5], axis=2))
|
|
|
images = np.concatenate(image_rows, axis=1)
|
|
|
images = np.squeeze(images, axis=0)
|
|
|
imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
|