12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import tensorflow as tf
- from utils import logger
- import ops
class Generator(object):
    """Encoder/decoder image generator conditioned on a latent code ``z``.

    Calling the instance builds the graph inside a variable scope named
    ``self.name``. The first call creates the variables; every later call
    reuses them (``self._reuse`` is flipped to True at the end of a call).

    After the first call, ``var_list`` holds the trainable variables whose
    names match the scope ``self.name``. It is not defined before that.
    """

    def __init__(self, name, is_train, norm='batch', image_size=128):
        """Store the generator configuration.

        Args:
            name: Variable-scope name; also used as the scope filter when
                collecting ``var_list``.
            is_train: Forwarded unchanged to every ``ops`` block.
            norm: Normalization identifier forwarded to the ``ops`` blocks
                (the first encoder block receives ``None`` instead).
            image_size: Height/width that ``z`` is tiled to before being
                concatenated with the input. A value of 256 adds one extra
                512-filter encoder stage.
        """
        logger.info('Init Generator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._reuse = False
        self._image_size = image_size

    # NOTE: `input` shadows the builtin, but it is part of the public
    # signature (callers may pass it by keyword), so it is kept as is.
    def __call__(self, input, z):
        """Build the generator graph and return its output tensor.

        Args:
            input: Image tensor. It is concatenated with the tiled ``z``
                along axis 3, so it presumably is NHWC with height and width
                equal to ``image_size`` -- confirm against the caller.
            z: Latent tensor whose last dimension must be statically known.

        Returns:
            The output of the final ``ops.deconv_block`` (3 filters, 'tanh').
        """
        with tf.variable_scope(self.name, reuse=self._reuse):
            latent_dim = int(z.get_shape()[-1])

            num_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)

            # Broadcast z over every spatial position and append it to the
            # input as extra channels. Using -1 lets TensorFlow infer the
            # batch dimension, so this step no longer fails with a TypeError
            # when the static batch size of `input` is unknown (None).
            # Whether the `ops` helpers also accept an unknown batch size is
            # not visible from here.
            z = tf.reshape(z, [-1, 1, 1, latent_dim])
            z = tf.tile(z, [1, self._image_size, self._image_size, 1])
            G = tf.concat([input, z], axis=3)

            # Encoder: one conv block per entry, keeping each output for the
            # skip connections below. The first block (i == 0) gets no norm.
            # The positional 4 and 2 are presumably kernel size and stride --
            # confirm against ops.conv_block.
            layers = []
            for i, n in enumerate(num_filters):
                G = ops.conv_block(G, n, 'C{}_{}'.format(n, i), 4, 2,
                                   self._is_train, self._reuse,
                                   norm=self._norm if i else None,
                                   activation='leaky')
                layers.append(G)

            # The innermost encoder output is the decoder's starting point,
            # so it is dropped from the skip list; the filter list is then
            # walked in reverse for the decoder.
            layers.pop()
            num_filters.pop()
            num_filters.reverse()

            # Decoder: each deconv output is concatenated (channel axis) with
            # the most recently stored encoder output.
            for i, n in enumerate(num_filters):
                G = ops.deconv_block(G, n, 'CD{}_{}'.format(n, i), 4, 2,
                                     self._is_train, self._reuse,
                                     norm=self._norm, activation='relu')
                G = tf.concat([G, layers.pop()], axis=3)

            G = ops.deconv_block(G, 3, 'last_layer', 4, 2, self._is_train,
                                 self._reuse, norm=None, activation='tanh')

            # Subsequent calls must share the variables created above.
            self._reuse = True
            self.var_list = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
            return G
|