Browse source

fix residual network

Youngwoon Lee 6 years ago
Commit
af75776f8e
2 changed files with 6 additions and 3 deletions

+ 2 - 2
encoder.py

@@ -55,9 +55,9 @@ class Encoder(object):
                                self._reuse, norm=None, activation='leaky', bias=True)
             for i, n in enumerate(num_filters):
                 E = ops.residual(E, n, 'res{}_{}'.format(n, i + 1), self._is_train,
-                                 self._reuse, norm=self._norm, activation='leaky',
-                                 bias=True)
+                                 self._reuse, norm=self._norm, bias=True)
                 E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
+            E = tf.nn.relu(E)
             E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
             E = ops.flatten(E)
             mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
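
The two changes here go together: the activation='leaky' keyword is dropped from the ops.residual call (the block applies its own output ReLU, as the ops.py hunk below shows), and one explicit ReLU is added after the residual/pooling stack. A minimal sketch of the trunk's shape after this commit, assuming the TF 1.x API and stubbing the residual block as a single 3x3 conv for brevity (encoder_trunk_sketch and the filter counts are illustrative, not from the repo):

    import tensorflow as tf  # TF 1.x API, matching the repo

    def encoder_trunk_sketch(E, num_filters=(64, 128, 256)):
        # Per stage: residual block (stubbed as one conv here), then 2x2
        # average pooling; a single ReLU follows the whole stack, as the
        # added line above does, before the final 8x8 average pool.
        for i, n in enumerate(num_filters):
            E = tf.layers.conv2d(E, n, 3, 1, padding='SAME',
                                 name='res{}_{}'.format(n, i + 1))
            E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
        E = tf.nn.relu(E)
        E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
        return E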

+ 4 - 1
ops.py

@@ -107,7 +107,10 @@ def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT',
             out = conv2d(out, num_filters, 3, 1, reuse, pad, bias=bias)
             out = _norm(out, is_train, reuse, norm)
 
-        return tf.nn.relu(input + out)
+        with tf.variable_scope('shortcut', reuse=reuse):
+            shortcut = conv2d(input, num_filters, 1, 1, reuse, pad, bias=bias)
+
+        return tf.nn.relu(shortcut + out)
 
 def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
                  norm, activation):
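
This hunk is the actual fix named in the commit message: the old block returned tf.nn.relu(input + out), which only builds when num_filters matches the input's channel depth, so any stage that widened the network failed with a shape error at graph-construction time. Routing the shortcut through a 1x1 conv (the projection shortcut from He et al.'s ResNet) makes the two branches addable for any num_filters. A minimal standalone sketch of the fixed block (residual_sketch is an illustrative name), using tf.layers.conv2d in place of the repo's ops.conv2d wrapper and omitting normalization:

    import tensorflow as tf  # TF 1.x API, matching the repo

    def residual_sketch(x, num_filters, name, reuse=False):
        # Two 3x3 convs on the main path; a 1x1 conv projects the shortcut
        # so both branches have num_filters channels before the addition.
        with tf.variable_scope(name, reuse=reuse):
            out = tf.layers.conv2d(x, num_filters, 3, 1, padding='SAME',
                                   name='conv1')
            out = tf.nn.relu(out)
            out = tf.layers.conv2d(out, num_filters, 3, 1, padding='SAME',
                                   name='conv2')
            # Before this commit: return tf.nn.relu(x + out), which raises
            # a shape error whenever x does not already have num_filters
            # channels.
            shortcut = tf.layers.conv2d(x, num_filters, 1, 1, padding='SAME',
                                        name='shortcut')
            return tf.nn.relu(shortcut + out)

Applying the projection unconditionally, as the commit does, keeps the code simple at the cost of a few extra parameters per block; canonical ResNets use an identity shortcut when the channel counts already match and project only on the widening stages.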