# generator.py
import tensorflow as tf

import ops
from utils import logger
class Generator(object):
    """U-Net style generator (TF1 graph mode).

    Tiles a latent code ``z`` across the spatial dimensions of the input
    image, concatenates it channel-wise, then runs a stride-2 conv encoder
    and a stride-2 deconv decoder with skip connections between mirrored
    stages. Variables live under a ``tf.variable_scope`` named ``name`` and
    are reused on every call after the first.
    """

    def __init__(self, name, is_train, norm='batch', image_size=128):
        """Store configuration; no graph is built until ``__call__``.

        Args:
            name: variable-scope name under which all generator variables
                are created (also used to collect ``self.var_list``).
            is_train: training flag forwarded to ``ops.conv_block`` /
                ``ops.deconv_block`` (presumably a bool or boolean
                placeholder controlling batch norm -- verify against ops).
            norm: normalization type forwarded to the conv/deconv blocks.
            image_size: spatial size of the (square) input; 256 adds one
                extra encoder/decoder stage.
        """
        logger.info('Init Generator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        # Graph-mode reuse flag: False for the first (variable-creating)
        # call, flipped to True at the end of __call__ so later calls
        # share the same variables.
        self._reuse = False
        self._image_size = image_size

    def __call__(self, input, z):
        """Build (or reuse) the generator graph and return its output.

        Args:
            input: image batch tensor; assumed NHWC with a static batch
                size and spatial dims equal to ``self._image_size`` --
                TODO confirm against callers.
            z: latent code tensor whose last dimension is the latent size.

        Returns:
            A 3-channel tensor produced by the final ``tanh``-activated
            deconv block (presumably images in [-1, 1]).
        """
        with tf.variable_scope(self.name, reuse=self._reuse):
            batch_size = int(input.get_shape()[0])
            latent_dim = int(z.get_shape()[-1])
            # Encoder widths: seven stride-2 convs take a 128px input down
            # to a 1x1 bottleneck; a 256px input needs one more stage.
            num_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)
            layers = []
            G = input
            # Broadcast z to (batch, H, W, latent_dim) and concatenate it
            # onto the input channels so every spatial location sees z.
            z = tf.reshape(z, [batch_size, 1, 1, latent_dim])
            z = tf.tile(z, [1, self._image_size, self._image_size, 1])
            G = tf.concat([G, z], axis=3)
            # Encoder: stride-2 conv blocks with leaky-relu; the first
            # block skips normalization (norm=None when i == 0).
            for i, n in enumerate(num_filters):
                G = ops.conv_block(G, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                   self._reuse, norm=self._norm if i else None, activation='leaky')
                layers.append(G)
            # The deepest activation is G itself, so drop it from the skip
            # list, and mirror the filter schedule for the decoder.
            layers.pop()
            num_filters.pop()
            num_filters.reverse()
            # Decoder: stride-2 deconv blocks, each followed by U-Net skip
            # concatenation with the matching encoder activation.
            for i, n in enumerate(num_filters):
                G = ops.deconv_block(G, n, 'CD{}_{}'.format(n, i), 4, 2, self._is_train,
                                     self._reuse, norm=self._norm, activation='relu')
                G = tf.concat([G, layers.pop()], axis=3)
            # Final deconv back to 3 channels with tanh, no normalization.
            G = ops.deconv_block(G, 3, 'last_layer', 4, 2, self._is_train,
                                 self._reuse, norm=None, activation='tanh')
            self._reuse = True  # subsequent calls reuse these variables
            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
            return G