@@ -5,7 +5,7 @@ import ops
 
 class Encoder(object):
     def __init__(self, name, is_train, norm='instance', activation='leaky',
-                 image_size=128, latent_dim=8):
+                 image_size=128, latent_dim=8, use_resnet=True):
         logger.info('Init Encoder %s', name)
         self.name = name
         self._is_train = is_train
@@ -14,8 +14,15 @@ class Encoder(object):
         self._reuse = False
         self._image_size = image_size
         self._latent_dim = latent_dim
+        self._use_resnet = use_resnet
 
     def __call__(self, input):
+        if self._use_resnet:
+            return self._resnet(input)
+        else:
+            return self._convnet(input)
+
+    def _convnet(self, input):
         with tf.variable_scope(self.name, reuse=self._reuse):
             num_filters = [64, 128, 256, 512, 512, 512, 512]
             if self._image_size == 256:
@@ -24,12 +31,39 @@ class Encoder(object):
             E = input
             for i, n in enumerate(num_filters):
                 E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
-                                   self._reuse, norm=self._norm if i else None, activation='leaky')
-            E = tf.reshape(E, [-1, 512])
+                                   self._reuse, norm=self._norm if i else None, activation='leaky')
+            E = ops.flatten(E)
+            mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
+                         norm=None, activation=None)
+            log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
+                                norm=None, activation=None)
+
+            z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
+
+            self._reuse = True
+            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
+            return z, mu, log_sigma
+
+    def _resnet(self, input):
+        with tf.variable_scope(self.name, reuse=self._reuse):
+            num_filters = [128, 256, 512, 512]
+            if self._image_size == 256:
+                num_filters.append(512)
+
+            E = input
+            E = ops.conv_block(E, 64, 'C{}_{}'.format(64, 0), 4, 2, self._is_train,
+                               self._reuse, norm=None, activation='leaky', bias=True)
+            for i, n in enumerate(num_filters):
+                E = ops.residual(E, n, 'res{}_{}'.format(n, i + 1), self._is_train,
+                                 self._reuse, norm=self._norm, activation='leaky',
+                                 bias=True)
+                E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
+            E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
+            E = ops.flatten(E)
             mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
-                        norm=None, activation=None)
+                         norm=None, activation=None)
             log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
-                                norm=None, activation=None)
 
             z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
 
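
Note on the sampling line: `tf.shape(self._latent_dim)` takes the shape of a
Python int, which is a rank-0 shape `[]`, so `tf.random_normal` returns a
single scalar that broadcasts across the batch and every latent dimension.
The usual reparameterization draws independent noise per dimension. A minimal
corrected sketch (standalone TF 1.x; `sample_z` is a hypothetical helper, not
part of this patch):

    import tensorflow as tf

    def sample_z(mu, log_sigma):
        # Hypothetical helper, not from the patch.
        # z = mu + sigma * eps with eps ~ N(0, I); tf.shape(mu) is
        # [batch, latent_dim], so each latent dimension of each sample
        # gets its own noise draw instead of one shared scalar.
        eps = tf.random_normal(shape=tf.shape(mu))
        return mu + eps * tf.exp(log_sigma)

The same sampling line appears in both `_convnet` and `_resnet`, so a fix
would apply to both paths.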
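For reference, the downsampling arithmetic in `_resnet`: only the stride-2
`conv_block`, the in-loop 2x2 `avg_pool`, and the final 8x8 `avg_pool` shrink
the spatial dimensions. This sketch assumes `ops.residual` preserves height
and width, which the patch itself does not show:

    import math

    def resnet_spatial_size(image_size):
        # Sketch assuming ops.residual keeps spatial size unchanged.
        size = math.ceil(image_size / 2)        # initial stride-2 conv_block
        stages = 5 if image_size == 256 else 4  # residual + 2x2 avg_pool stages
        for _ in range(stages):
            size = math.ceil(size / 2)
        return math.ceil(size / 8)              # final 8x8 'SAME' avg_pool

    assert resnet_spatial_size(128) == 1        # 64 -> 32 -> 16 -> 8 -> 4 -> 1
    assert resnet_spatial_size(256) == 1        # 128 -> 64 -> ... -> 4 -> 1

Under that assumption, `ops.flatten` yields a [batch, 512] feature in both
configurations, matching what the `FC8_mu` and `FC8_sigma` layers consume.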