encoder.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import tensorflow as tf
  2. from utils import logger
  3. import ops
  4. class Encoder(object):
  5. def __init__(self, name, is_train, norm='instance', activation='leaky',
  6. image_size=128, latent_dim=8, use_resnet=True):
  7. logger.info('Init Encoder %s', name)
  8. self.name = name
  9. self._is_train = is_train
  10. self._norm = norm
  11. self._activation = activation
  12. self._reuse = False
  13. self._image_size = image_size
  14. self._latent_dim = latent_dim
  15. self._use_resnet = use_resnet
  16. def __call__(self, input):
  17. if self._use_resnet:
  18. return self._resnet(input)
  19. else:
  20. return self._convnet(input)
  21. def _convnet(self, input):
  22. with tf.variable_scope(self.name, reuse=self._reuse):
  23. num_filters = [64, 128, 256, 512, 512, 512, 512]
  24. if self._image_size == 256:
  25. num_filters.append(512)
  26. E = input
  27. for i, n in enumerate(num_filters):
  28. E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
  29. self._reuse, norm=self._norm if i else None, activation='leaky')
  30. E = ops.flatten(E)
  31. mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
  32. norm=None, activation=None)
  33. log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
  34. norm=None, activation=None)
  35. z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
  36. self._reuse = True
  37. self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
  38. return z, mu, log_sigma
  39. def _resnet(self, input):
  40. with tf.variable_scope(self.name, reuse=self._reuse):
  41. num_filters = [128, 256, 512, 512]
  42. if self._image_size == 256:
  43. num_filters.append(512)
  44. E = input
  45. E = ops.conv_block(E, 64, 'C{}_{}'.format(64, 0), 4, 2, self._is_train,
  46. self._reuse, norm=None, activation='leaky', bias=True)
  47. for i, n in enumerate(num_filters):
  48. E = ops.residual(E, n, 'res{}_{}'.format(n, i + 1), self._is_train,
  49. self._reuse, norm=self._norm, bias=True)
  50. E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
  51. E = tf.nn.relu(E)
  52. E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
  53. E = ops.flatten(E)
  54. mu = ops.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
  55. norm=None, activation=None)
  56. log_sigma = ops.mlp(E, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
  57. norm=None, activation=None)
  58. z = mu + tf.random_normal(shape=tf.shape(self._latent_dim)) * tf.exp(log_sigma)
  59. self._reuse = True
  60. self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
  61. return z, mu, log_sigma