encoder.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334
  1. import tensorflow as tf
  2. from utils import logger
  3. import ops
  4. class Encoder(object):
  5. def __init__(self, name, is_train, norm='instance', activation='leaky',
  6. image_size=128, latent_dim=8):
  7. logger.info('Init Encoder %s', name)
  8. self.name = name
  9. self._is_train = is_train
  10. self._norm = norm
  11. self._activation = activation
  12. self._reuse = False
  13. self._image_size = image_size
  14. self._latent_dim = latent_dim
  15. def __call__(self, input):
  16. with tf.variable_scope(self.name, reuse=self._reuse):
  17. num_filters = [64, 128, 256, 512, 512, 512, 512]
  18. if self._image_size == 256:
  19. num_filters.append(512)
  20. E = input
  21. for i, n in enumerate(num_filters):
  22. E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
  23. self._reuse, norm=self._norm if i else None, activation='leaky')
  24. E = tf.reshape(E, [-1, 512])
  25. E = ops.mlp(E, self._latent_dim, 'FC8', self._is_train, self._reuse,
  26. norm=None, activation=None)
  27. self._reuse = True
  28. self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
  29. return E