# discriminator_z.py
  1. import tensorflow as tf
  2. from utils import logger
  3. import ops
  4. class DiscriminatorZ(object):
  5. def __init__(self, name, is_train, norm='batch', activation='relu'):
  6. logger.info('Init DiscriminatorZ %s', name)
  7. self.name = name
  8. self._is_train = is_train
  9. self._norm = norm
  10. self._activation = activation
  11. self._reuse = False
  12. def __call__(self, input):
  13. with tf.variable_scope(self.name, reuse=self._reuse):
  14. D = input
  15. for i in range(3):
  16. D = ops.mlp(D, 512, 'FC512_{}'.format(i), self._is_train,
  17. self._reuse, self._norm, self._activation)
  18. D = ops.mlp(D, 1, 'FC1_{}'.format(i), self._is_train,
  19. self._reuse, norm=None, activation=None)
  20. self._reuse = True
  21. self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
  22. return D