123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- import tensorflow as tf
- from utils import logger
- import ops
class Encoder(object):
    """Encoder mapping an image tensor to a sampled latent code ``(z, mu, log_sigma)``.

    Uses TensorFlow 1.x graph-mode APIs (``tf.variable_scope``,
    ``tf.get_collection``). Each call builds ops under the variable scope
    ``name``. The first call creates variables; later calls reuse them because
    ``self._reuse`` is flipped to True after a build.
    """

    def __init__(self, name, is_train, norm='instance', activation='leaky',
                 image_size=128, latent_dim=8, use_resnet=True):
        """Store the configuration; no graph ops are created here.

        Args:
            name: variable-scope name under which all of this encoder's
                variables are created. It is also used to collect ``var_list``.
            is_train: forwarded unchanged to the ``ops`` layer helpers.
            norm: normalization name passed to ``ops`` for the normalized
                layers. The first conv layer of either backbone uses ``None``.
            activation: stored as ``self._activation``.
                NOTE(review): neither backbone reads it; both pass
                ``'leaky'`` explicitly to ``ops.conv_block``. Confirm intent.
            image_size: when equal to 256, one extra 512-filter downsampling
                stage is appended to the backbone.
            latent_dim: output width of the ``mu`` / ``log_sigma`` heads.
            use_resnet: True selects the residual backbone, False the plain
                strided-convolution backbone.
        """
        logger.info('Init Encoder %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._activation = activation
        self._reuse = False
        self._image_size = image_size
        self._latent_dim = latent_dim
        self._use_resnet = use_resnet

    def __call__(self, input):
        """Build the selected backbone on ``input``; return ``(z, mu, log_sigma)``."""
        if self._use_resnet:
            return self._resnet(input)
        return self._convnet(input)

    def _convnet(self, input):
        """Plain backbone: a stack of 4x4 stride-2 conv blocks, then the latent head."""
        with tf.variable_scope(self.name, reuse=self._reuse):
            num_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)
            E = input
            for i, n in enumerate(num_filters):
                # The first block (i == 0) is deliberately left unnormalized.
                E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                   self._reuse, norm=self._norm if i else None,
                                   activation='leaky')
            E = ops.flatten(E)
            z, mu, log_sigma = self._latent_head(E)
            self._finish_build()
        return z, mu, log_sigma

    def _resnet(self, input):
        """Residual backbone, followed by the latent head.

        Stages: an unnormalized 4x4 stride-2 conv, then residual blocks each
        followed by 2x2 average pooling, then ReLU, then 8x8 average pooling.
        The pooling windows are written as [1, k, k, 1], which assumes
        channels-last (NHWC) input. TODO confirm against callers.
        """
        with tf.variable_scope(self.name, reuse=self._reuse):
            num_filters = [128, 256, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)
            E = ops.conv_block(input, 64, 'C{}_{}'.format(64, 0), 4, 2, self._is_train,
                               self._reuse, norm=None, activation='leaky', bias=True)
            for i, n in enumerate(num_filters):
                E = ops.residual(E, n, 'res{}_{}'.format(n, i + 1), self._is_train,
                                 self._reuse, norm=self._norm, bias=True)
                E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
            E = tf.nn.relu(E)
            E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
            E = ops.flatten(E)
            z, mu, log_sigma = self._latent_head(E)
            self._finish_build()
        return z, mu, log_sigma

    def _latent_head(self, E):
        """Project flattened features to ``(z, mu, log_sigma)``.

        ``z = mu + eps * exp(log_sigma)`` with ``eps`` drawn from a standard
        normal. Must be called inside this encoder's variable scope so the
        ``FC8_mu`` / ``FC8_sigma`` variables are created or reused there.
        """
        mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
                     norm=None, activation=None)
        log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
                            norm=None, activation=None)
        # The noise must have the same (dynamic) shape as mu so that every
        # sample and every latent dimension gets an independent draw.
        # tf.shape(self._latent_dim) would be the shape of a Python int, i.e. a
        # scalar, which broadcasts one draw over the whole batch.
        eps = tf.random_normal(shape=tf.shape(mu))
        z = mu + eps * tf.exp(log_sigma)
        return z, mu, log_sigma

    def _finish_build(self):
        """Mark variables as created and refresh ``var_list`` with this scope's trainables."""
        self._reuse = True
        # get_collection filters names with re.match (a prefix match). The
        # trailing '/' keeps a scope such as 'E' from also collecting the
        # variables of 'E_2/...' or 'Enc/...'.
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          self.name + '/')
|