Pārlūkot izejas kodu

add resnet for encoder

youngwoon 7 gadi atpakaļ
vecāks
revīzija
dbf07c8ec8
4 mainītis faili ar 52 papildinājumiem un 9 dzēšanām
  1. 2 0
      bicycle-gan.py
  2. 39 5
      encoder.py
  3. 3 1
      model.py
  4. 8 3
      ops.py

+ 2 - 0
bicycle-gan.py

@@ -35,6 +35,8 @@ parser.add_argument('--image_size', default=256, type=int,
                     help="Image size")
 parser.add_argument('--latent_dim', default=8, type=int,
                     help="Dimensionality of latent vector")
+parser.add_argument('--use_resnet', default=True, type=bool,
+                    help="Use the ResNet model for the encoder")
 parser.add_argument('--load_model', default='',
                     help='Model path to load (e.g., train_2017-07-07_01-23-45)')
 parser.add_argument('--gpu', default="1", type=str,

+ 39 - 5
encoder.py

@@ -5,7 +5,7 @@ import ops
 
 class Encoder(object):
     def __init__(self, name, is_train, norm='instance', activation='leaky',
-                 image_size=128, latent_dim=8):
+                 image_size=128, latent_dim=8, use_resnet=True):
         logger.info('Init Encoder %s', name)
         self.name = name
         self._is_train = is_train
@@ -14,8 +14,15 @@ class Encoder(object):
         self._reuse = False
         self._image_size = image_size
         self._latent_dim = latent_dim
+        self._use_resnet = use_resnet
 
     def __call__(self, input):
+        if self._use_resnet:
+            return self._resnet(input)
+        else:
+            return self._convnet(input)
+
+    def _convnet(self, input):
         with tf.variable_scope(self.name, reuse=self._reuse):
             num_filters = [64, 128, 256, 512, 512, 512, 512]
             if self._image_size == 256:
@@ -24,12 +31,39 @@ class Encoder(object):
             E = input
             for i, n in enumerate(num_filters):
                 E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
-                                self._reuse, norm=self._norm if i else None, activation='leaky')
-            E = tf.reshape(E, [-1, 512])
+                                   self._reuse, norm=self._norm if i else None, activation='leaky')
+            E = ops.flatten(E)
+            mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
+                         norm=None, activation=None)
+            log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
+                                norm=None, activation=None)
+
+            z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
+
+            self._reuse = True
+            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
+            return z, mu, log_sigma
+
+    def _resnet(self, input):
+        with tf.variable_scope(self.name, reuse=self._reuse):
+            num_filters = [128, 256, 512, 512]
+            if self._image_size == 256:
+                num_filters.append(512)
+
+            E = input
+            E = ops.conv_block(E, 64, 'C{}_{}'.format(64, 0), 4, 2, self._is_train,
+                               self._reuse, norm=None, activation='leaky', bias=True)
+            for i, n in enumerate(num_filters):
+                E = ops.residual(E, n, 'res{}_{}'.format(n, i + 1), self._is_train,
+                                 self._reuse, norm=self._norm, activation='leaky',
+                                 bias=True)
+                E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
+            E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
+            E = ops.flatten(E)
             mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
-                        norm=None, activation=None)
+                         norm=None, activation=None)
             log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
-                        norm=None, activation=None)
+                                norm=None, activation=None)
 
             z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
 

+ 3 - 1
model.py

@@ -22,6 +22,7 @@ class BicycleGAN(object):
         self._coeff_latent = args.coeff_latent
         self._coeff_kl = args.coeff_kl
         self._norm = 'instance' if args.instance_normalization else 'batch'
+        self._use_resnet = args.use_resnet
 
         self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
         self._image_shape = [self._image_size, self._image_size, 3]
@@ -64,7 +65,8 @@ class BicycleGAN(object):
         # Encoder
         E = Encoder('E', is_train=self.is_train,
                     norm=self._norm, activation='relu',
-                    image_size=self._image_size, latent_dim=self._latent_dim)
+                    image_size=self._image_size, latent_dim=self._latent_dim,
+                    use_resnet=self._use_resnet)
 
         # conditional VAE-GAN: B -> z -> B'
         z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)

+ 8 - 3
ops.py

@@ -1,4 +1,5 @@
 import tensorflow as tf
+import numpy as np
 
 
 def _norm(input, is_train, reuse=True, norm=None):
@@ -39,6 +40,9 @@ def _activation(input, activation=None):
     else:
         return input
 
+def flatten(input):
+    return tf.reshape(input, [-1, np.prod(input.get_shape().as_list()[1:])])
+
 def conv2d(input, num_filters, filter_size, stride, reuse=False,
            pad='SAME', dtype=tf.float32, bias=False):
     stride_shape = [1, stride, stride, 1]
@@ -91,15 +95,16 @@ def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
         out = _activation(out, activation)
         return out
 
-def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT'):
+def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT',
+             bias=False):
     with tf.variable_scope(name, reuse=reuse):
         with tf.variable_scope('res1', reuse=reuse):
-            out = conv2d(input, num_filters, 3, 1, reuse, pad)
+            out = conv2d(input, num_filters, 3, 1, reuse, pad, bias=bias)
             out = _norm(out, is_train, reuse, norm)
             out = tf.nn.relu(out)
 
         with tf.variable_scope('res2', reuse=reuse):
-            out = conv2d(out, num_filters, 3, 1, reuse, pad)
+            out = conv2d(out, num_filters, 3, 1, reuse, pad, bias=bias)
             out = _norm(out, is_train, reuse, norm)
 
         return tf.nn.relu(input + out)