discriminator.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334
  1. import tensorflow as tf
  2. from utils import logger
  3. import ops
  4. class Discriminator(object):
  5. def __init__(self, name, is_train, norm='instance', activation='leaky', image_size=128):
  6. logger.info('Init Discriminator %s', name)
  7. self.name = name
  8. self._is_train = is_train
  9. self._norm = norm
  10. self._activation = activation
  11. self._reuse = False
  12. self._image_size = image_size
  13. def __call__(self, input):
  14. with tf.variable_scope(self.name, reuse=self._reuse):
  15. D = ops.conv_block(input, 64, 'C64', 4, 2, self._is_train,
  16. self._reuse, norm=None, activation=self._activation)
  17. D = ops.conv_block(D, 128, 'C128', 4, 2, self._is_train,
  18. self._reuse, self._norm, self._activation)
  19. D = ops.conv_block(D, 256, 'C256', 4, 2, self._is_train,
  20. self._reuse, self._norm, self._activation)
  21. num_layers = 3 if self._image_size == 256 else 1
  22. for i in range(num_layers):
  23. D = ops.conv_block(D, 512, 'C512_{}'.format(i), 4, 2, self._is_train,
  24. self._reuse, self._norm, self._activation)
  25. D = ops.conv_block(D, 1, 'C1', 4, 1, self._is_train,
  26. self._reuse, norm=None, activation=None, bias=True)
  27. D = tf.reduce_mean(D, axis=[1,2,3])
  28. self._reuse = True
  29. self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
  30. return D