Youngwoon Lee 7 years ago
parent
commit
d4e5104e3a
11 changed files with 872 additions and 0 deletions
  1. .gitignore  +111 −0
  2. bicycle-gan.py  +123 −0
  3. data_loader.py  +110 −0
  4. discriminator.py  +34 −0
  5. discriminator_z.py  +26 −0
  6. download_pix2pix_dataset.sh  +14 −0
  7. encoder.py  +34 −0
  8. generator.py  +47 −0
  9. model.py  +246 −0
  10. ops.py  +115 −0
  11. utils.py  +12 −0

+ 111 - 0
.gitignore

@@ -0,0 +1,111 @@
+# Tensorflow logs / datasets / results
+logs/
+datasets/
+results/
+
+# Temporary files
+*.zip
+*.swp
+*~
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/

+ 123 - 0
bicycle-gan.py

@@ -0,0 +1,123 @@
+import argparse
+import sys
+import signal
+import os
+from datetime import datetime
+
+import tensorflow as tf
+
+from data_loader import get_data
+from model import BicycleGAN
+from utils import logger, makedirs
+
+
+# parse command-line arguments
+def str2bool(v):
+    return v.lower() == 'true'
+
+parser = argparse.ArgumentParser(description="Run commands")
+parser.add_argument('--train', default=True, type=str2bool,
+                    help="Training mode")
+parser.add_argument('--task', type=str, default='edges2shoes',
+                    help='Task name')
+parser.add_argument('--gamma', type=float, default=1,
+                    help='Loss coefficient')
+parser.add_argument('--lambda1', type=float, default=1,
+                    help='Loss coefficient')
+parser.add_argument('--lambda2', type=float, default=1,
+                    help='Loss coefficient')
+parser.add_argument('--instance_normalization', default=False, type=str2bool,
+                    help="Use instance norm instead of batch norm")
+parser.add_argument('--log_step', default=100, type=int,
+                    help="Tensorboard log frequency")
+parser.add_argument('--batch_size', default=1, type=int,
+                    help="Batch size")
+parser.add_argument('--image_size', default=128, type=int,
+                    help="Image size")
+parser.add_argument('--latent_dim', default=8, type=int,
+                    help="Dimensionality of latent vector")
+parser.add_argument('--load_model', default='',
+                    help='Model path to load (e.g., train_2017-07-07_01-23-45)')
+parser.add_argument('--gpu', default="1", type=str,
+                    help="gpu index for CUDA_VISIBLE_DEVICES")
+
+
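+# Example invocation (hypothetical values):
+#   python bicycle-gan.py --task edges2shoes --image_size 128 --gpu 0
+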
+class FastSaver(tf.train.Saver):
+    def save(self, sess, save_path, global_step=None, latest_filename=None,
+             meta_graph_suffix="meta", write_meta_graph=True):
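+        # Always skip writing the meta graph; it is large and only needed once.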
+        super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
+                                    meta_graph_suffix, False)
+
+
+def run(args):
+    # setting the GPU #
+    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+
+    logger.info('Read data:')
+    train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)
+
+    logger.info('Build graph:')
+    model = BicycleGAN(args)
+
+    variables_to_save = tf.global_variables()
+    init_op = tf.variables_initializer(variables_to_save)
+    init_all_op = tf.global_variables_initializer()
+    saver = FastSaver(variables_to_save)
+
+    logger.info('Trainable vars:')
+    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+                                 tf.get_variable_scope().name)
+    for v in var_list:
+        logger.info('  %s %s', v.name, v.get_shape())
+
+    if args.load_model != '':
+        model_name = args.load_model
+    else:
+        model_name = '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+    logdir = './logs'
+    makedirs(logdir)
+    logdir = os.path.join(logdir, model_name)
+    logger.info('Events directory: %s', logdir)
+    summary_writer = tf.summary.FileWriter(logdir)
+
+    def init_fn(sess):
+        logger.info('Initializing all parameters.')
+        sess.run(init_all_op)
+
+    sv = tf.train.Supervisor(is_chief=True,
+                             logdir=logdir,
+                             saver=saver,
+                             summary_op=None,
+                             init_op=init_op,
+                             init_fn=init_fn,
+                             summary_writer=summary_writer,
+                             ready_op=tf.report_uninitialized_variables(variables_to_save),
+                             global_step=model.global_step,
+                             save_model_secs=300,
+                             save_summaries_secs=30)
+
+    if args.train:
+        logger.info("Starting training session.")
+        with sv.managed_session() as sess:
+            model.train(sess, summary_writer, train_A, train_B)
+
+    logger.info("Starting testing session.")
+    with sv.managed_session() as sess:
+        base_dir = os.path.join('results', model_name)
+        makedirs(base_dir)
+        model.test(sess, test_A, test_B, base_dir)
+
+def main():
+    args, unparsed = parser.parse_known_args()
+
+    def shutdown(signum, frame):
+        tf.logging.warn('Received signal %s: exiting', signum)
+        sys.exit(128 + signum)
+    signal.signal(signal.SIGHUP, shutdown)
+    signal.signal(signal.SIGINT, shutdown)
+    signal.signal(signal.SIGTERM, shutdown)
+
+    run(args)
+
+if __name__ == "__main__":
+    main()

+ 110 - 0
data_loader.py

@@ -0,0 +1,110 @@
+import os
+from glob import glob
+
+from scipy.misc import imread, imresize
+import numpy as np
+from tqdm import tqdm
+import h5py
+
+datasets = ['maps', 'cityscapes', 'facades', 'edges2handbags', 'edges2shoes']
+
+def read_image(path):
+    image = imread(path)
+    if len(image.shape) != 3 or image.shape[2] != 3:
+        print('Wrong image {} with shape {}'.format(path, image.shape))
+        return None, None
+
+    # pix2pix images store the paired domains side by side; split at the width midpoint
+    h, w, c = image.shape
+    assert w == 256 or w == 512, 'Image size mismatch ({}, {})'.format(h, w)
+    assert h == 128 or h == 256, 'Image size mismatch ({}, {})'.format(h, w)
+    image_a = image[:, :w // 2, :].astype(np.float32) / 255.0
+    image_b = image[:, w // 2:, :].astype(np.float32) / 255.0
+
+    # range of pixel values = [-1.0, 1.0]
+    image_a = image_a * 2.0 - 1.0
+    image_b = image_b * 2.0 - 1.0
+    return image_a, image_b
+
+def read_images(base_dir):
+    ret = []
+    for dir_name in ['train', 'val']:
+        data_dir = os.path.join(base_dir, dir_name)
+        paths = glob(os.path.join(data_dir, '*.jpg'))
+        print('# images in {}: {}'.format(data_dir, len(paths)))
+
+        images_A = []
+        images_B = []
+        for path in tqdm(paths):
+            image_A, image_B = read_image(path)
+            if image_A is not None:
+                images_A.append(image_A)
+                images_B.append(image_B)
+        ret.append((dir_name + 'A', images_A))
+        ret.append((dir_name + 'B', images_B))
+    return ret
+
+def store_h5py(base_dir, dir_name, images, image_size):
+    f = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'w')
+    for i in range(len(images)):
+        grp = f.create_group(str(i))
+        if images[i].shape[0] != image_size:
+            image = imresize(images[i], (image_size, image_size, 3))
+            # range of pixel values = [-1.0, 1.0]
+            image = image.astype(np.float32) / 255.0
+            image = image * 2.0 - 1.0
+            grp['image'] = image
+        else:
+            grp['image'] = images[i]
+    f.close()
+
+def convert_h5py(task_name):
+    print('Generating h5py file')
+    base_dir = os.path.join('datasets', task_name)
+    data = read_images(base_dir)
+    for dir_name, images in data:
+        if images[0].shape[0] == 256:
+            store_h5py(base_dir, dir_name, images, 256)
+        store_h5py(base_dir, dir_name, images, 128)
+
+def read_h5py(task_name, image_size):
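+    # load the cached h5py files; rebuild them if any split file is missing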
+    base_dir = 'datasets/' + task_name
+    paths = glob(os.path.join(base_dir, '*_{}.hy'.format(image_size)))
+    if len(paths) != 4:
+        convert_h5py(task_name)
+    ret = []
+    for dir_name in ['trainA', 'trainB', 'valA', 'valB']:
+        try:
+            dataset = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'r')
+        except IOError:
+            raise IOError('Dataset file {}_{}.hy is not available; '
+                          'delete the dataset directory and run again'.format(dir_name, image_size))
+
+        images = []
+        for key in dataset:
+            images.append(dataset[key]['image'].value.astype(np.float32))
+        ret.append(images)
+    return ret
+
+def download_dataset(task_name):
+    print('Download data %s' % task_name)
+    cmd = './download_pix2pix_dataset.sh ' + task_name
+    os.system(cmd)
+
+def get_data(task_name, image_size):
+    assert task_name in datasets, 'Dataset {} is not available'.format(task_name)
+
+    if not os.path.exists('datasets'):
+        os.makedirs('datasets')
+
+    base_dir = os.path.join('datasets', task_name)
+    print('Check data %s' % base_dir)
+    if not os.path.exists(base_dir):
+        print('Dataset not found. Start downloading...')
+        download_dataset(task_name)
+        convert_h5py(task_name)
+
+    print('Load data %s' % task_name)
+    train_A, train_B, test_A, test_B = \
+        read_h5py(task_name, image_size)
+    return train_A, train_B, test_A, test_B

+ 34 - 0
discriminator.py

@@ -0,0 +1,34 @@
+import tensorflow as tf
+from utils import logger
+import ops
+
+
+class Discriminator(object):
+    def __init__(self, name, is_train, norm='instance', activation='leaky', image_size=128):
+        logger.info('Init Discriminator %s', name)
+        self.name = name
+        self._is_train = is_train
+        self._norm = norm
+        self._activation = activation
+        self._reuse = False
+        self._image_size = image_size
+
+    def __call__(self, input):
+        with tf.variable_scope(self.name, reuse=self._reuse):
+            D = ops.conv_block(input, 64, 'C64', 4, 2, self._is_train,
+                               self._reuse, norm=None, activation=self._activation)
+            D = ops.conv_block(D, 128, 'C128', 4, 2, self._is_train,
+                               self._reuse, self._norm, self._activation)
+            D = ops.conv_block(D, 256, 'C256', 4, 2, self._is_train,
+                               self._reuse, self._norm, self._activation)
+            num_layers = 3 if self._image_size == 256 else 1
+            for i in range(num_layers):
+                D = ops.conv_block(D, 512, 'C512_{}'.format(i), 4, 2, self._is_train,
+                                   self._reuse, self._norm, self._activation)
+            D = ops.conv_block(D, 1, 'C1', 4, 1, self._is_train,
+                               self._reuse, norm=None, activation=None, bias=True)
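+            # Average the patch responses into a single scalar score per image.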
+            D = tf.reduce_mean(D, axis=[1,2,3])
+
+            self._reuse = True
+            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
+            return D

+ 26 - 0
discriminator_z.py

@@ -0,0 +1,26 @@
+import tensorflow as tf
+from utils import logger
+import ops
+
+
+class DiscriminatorZ(object):
+    def __init__(self, name, is_train, norm='batch', activation='relu'):
+        logger.info('Init DiscriminatorZ %s', name)
+        self.name = name
+        self._is_train = is_train
+        self._norm = norm
+        self._activation = activation
+        self._reuse = False
+
+    def __call__(self, input):
+        with tf.variable_scope(self.name, reuse=self._reuse):
+            D = input
+            for i in range(3):
+                D = ops.mlp(D, 512, 'FC512_{}'.format(i), self._is_train,
+                            self._reuse, self._norm, self._activation)
+            D = ops.mlp(D, 1, 'FC1', self._is_train,
+                        self._reuse, norm=None, activation=None)
+
+            self._reuse = True
+            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
+            return D

+ 14 - 0
download_pix2pix_dataset.sh

@@ -0,0 +1,14 @@
+#!/bin/bash
+# usage: ./download_pix2pix_dataset.sh <dataset-name>
+FILE=$1
+
+if [[ $FILE != "cityscapes" && $FILE != "edges2handbags" && $FILE != "edges2shoes" &&  $FILE != "facades" && $FILE != "maps" ]]; then
+    echo "Available datasets are: edges2handbags, edges2shoes, maps, cityscapes, facades"
+    exit 1
+fi
+
+URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
+TAR_FILE=./datasets/$FILE.tar.gz
+TARGET_DIR=./datasets/$FILE/
+wget -N $URL -O $TAR_FILE
+mkdir -p $TARGET_DIR
+tar -zxvf $TAR_FILE -C ./datasets/
+rm $TAR_FILE

+ 34 - 0
encoder.py

@@ -0,0 +1,34 @@
+import tensorflow as tf
+from utils import logger
+import ops
+
+
+class Encoder(object):
+    def __init__(self, name, is_train, norm='instance', activation='leaky',
+                 image_size=128, latent_dim=8):
+        logger.info('Init Encoder %s', name)
+        self.name = name
+        self._is_train = is_train
+        self._norm = norm
+        self._activation = activation
+        self._reuse = False
+        self._image_size = image_size
+        self._latent_dim = latent_dim
+
+    def __call__(self, input):
+        with tf.variable_scope(self.name, reuse=self._reuse):
+            num_filters = [64, 128, 256, 512, 512, 512, 512]
+            if self._image_size == 256:
+                num_filters.append(512)
+
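+            # downsample to a 1x1 feature map, then project to the latent dimension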
+            E = input
+            for i, n in enumerate(num_filters):
+                E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
+                                   self._reuse, norm=self._norm if i else None, activation='leaky')
+            E = tf.reshape(E, [-1, 512])
+            E = ops.mlp(E, self._latent_dim, 'FC8', self._is_train, self._reuse,
+                        norm=None, activation=None)
+
+            self._reuse = True
+            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
+            return E

+ 47 - 0
generator.py

@@ -0,0 +1,47 @@
+import tensorflow as tf
+from utils import logger
+import ops
+
+
+class Generator(object):
+    def __init__(self, name, is_train, norm='batch', image_size=128):
+        logger.info('Init Generator %s', name)
+        self.name = name
+        self._is_train = is_train
+        self._norm = norm
+        self._reuse = False
+        self._image_size = image_size
+
+    def __call__(self, input, z):
+        with tf.variable_scope(self.name, reuse=self._reuse):
+            self._dropout = tf.constant(1.0)  # keep_prob for tf.nn.dropout; 1.0 disables dropout
+            batch_size = int(input.get_shape()[0])
+            latent_dim = int(z.get_shape()[-1])
+            num_filters = [64, 128, 256, 512, 512, 512, 512]
+            if self._image_size == 256:
+                num_filters.append(512)
+
+            layers = []
+            G = input
+            for i, n in enumerate(num_filters):
+                G = ops.conv_block(G, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
+                                   self._reuse, norm=self._norm if i else None, activation='leaky')
+                layers.append(G)
+
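+            # concatenate the latent code onto the 1x1 bottleneck features before decoding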
+            z = tf.reshape(z, [batch_size, 1, 1, latent_dim])
+            G = tf.concat([G, z], axis=3)
+
+            layers.pop()
+            num_filters.pop()
+            num_filters.reverse()
+
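+            # decode with U-Net skip connections from the mirrored encoder layers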
+            for i, n in enumerate(num_filters):
+                G = ops.deconv_block(G, n, 'CD{}_{}'.format(n, i), 4, 2, self._is_train,
+                                     self._reuse, norm=self._norm, activation='relu', dropout=self._dropout)
+                G = tf.concat([G, layers.pop()], axis=3)
+            G = ops.deconv_block(G, 3, 'last_layer', 4, 2, self._is_train,
+                                 self._reuse, norm=None, activation='tanh', dropout=self._dropout)
+
+            self._reuse = True
+            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
+            return G

+ 246 - 0
model.py

@@ -0,0 +1,246 @@
+import os
+import random
+
+from tqdm import trange, tqdm
+from scipy.misc import imsave
+import tensorflow as tf
+import numpy as np
+
+from generator import Generator
+from encoder import Encoder
+from discriminator import Discriminator
+from discriminator_z import DiscriminatorZ
+from utils import logger, makedirs
+
+
+class BicycleGAN(object):
+    def __init__(self, args):
+        self._log_step = args.log_step
+        self._batch_size = args.batch_size
+        self._image_size = args.image_size
+        self._latent_dim = args.latent_dim
+        self._lambda1 = args.lambda1
+        self._lambda2 = args.lambda2
+        self._gamma = args.gamma
+
+        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
+        self._image_shape = [self._image_size, self._image_size, 3]
+
+        self.is_train = tf.placeholder(tf.bool, name='is_train')
+        self.lr = tf.placeholder(tf.float32, name='lr')
+        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
+
+        image_a = self.image_a = \
+            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
+        image_b = self.image_b = \
+            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
+        z = self.z = \
+            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')
+
+        # Data augmentation
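+        # (one shared seed so image_a and image_b receive matching crops and flips)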
+        seed = random.randint(0, 2**31 - 1)
+        def augment_image(image):
+            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
+            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
+            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
+            return image
+
+        image_a = tf.cond(self.is_train,
+                          lambda: augment_image(image_a),
+                          lambda: image_a)
+        image_b = tf.cond(self.is_train,
+                          lambda: augment_image(image_b),
+                          lambda: image_b)
+
+        # Generator
+        G = Generator('G', is_train=self.is_train,
+                      norm='batch', image_size=self._image_size)
+
+        # Discriminator
+        D = Discriminator('D', is_train=self.is_train,
+                          norm='batch', activation='leaky',
+                          image_size=self._image_size)
+        Dz = DiscriminatorZ('Dz', is_train=self.is_train,
+                             norm='batch', activation='relu')
+
+        # Encoder
+        E = Encoder('E', is_train=self.is_train,
+                    norm='batch', activation='relu',
+                    image_size=self._image_size, latent_dim=self._latent_dim)
+
+        # Generate images (a->b)
+        image_ab = self.image_ab = G(image_a, z)
+        z_reconstruct = E(image_ab)
+
+        # Encode z (G(A, z) -> z)
+        z_encoded = E(image_b)
+        image_ab_encoded = G(image_a, z_encoded)
+
+        # Discriminate real/fake images
+        D_real = D(image_b)
+        D_fake = D(image_ab)
+        D_fake_encoded = D(image_ab_encoded)
+        Dz_real = Dz(z)
+        Dz_fake = Dz(z_encoded)
+
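+        # Least-squares GAN losses with a smoothed real-label target of 0.9,
+        # plus L1 image-reconstruction and latent-cycle terms.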
+        loss_image_reconstruct = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))
+
+        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
+            tf.reduce_mean(tf.square(D_fake))) * 0.5
+
+        loss_image_cycle = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
+            tf.reduce_mean(tf.square(D_fake_encoded))) * 0.5
+
+        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_reconstruct))
+
+        loss_Dz = (tf.reduce_mean(tf.squared_difference(Dz_real, 0.9)) +
+            tf.reduce_mean(tf.square(Dz_fake))) * 0.5
+
+        loss = self._gamma * loss_Dz \
+            + loss_image_cycle - self._lambda1 * loss_image_reconstruct \
+            + loss_gan - self._lambda2 * loss_latent_cycle
+
+        # Optimizer
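+        # D and Dz minimize the objective; G and E minimize its negation (min-max).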
+        self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                            .minimize(loss, var_list=D.var_list, global_step=self.global_step)
+        self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                            .minimize(-loss, var_list=G.var_list)
+        self.optimizer_Dz = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                            .minimize(loss, var_list=Dz.var_list)
+        self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
+                            .minimize(-loss, var_list=E.var_list)
+
+        # Summaries
+        self.loss_image_reconstruct = loss_image_reconstruct
+        self.loss_image_cycle = loss_image_cycle
+        self.loss_latent_cycle = loss_latent_cycle
+        self.loss_gan = loss_gan
+        self.loss_Dz = loss_Dz
+        self.loss = loss
+
+        tf.summary.scalar('loss/image_reconstruct', loss_image_reconstruct)
+        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
+        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
+        tf.summary.scalar('loss/gan', loss_gan)
+        tf.summary.scalar('loss/Dz', loss_Dz)
+        tf.summary.scalar('loss/total', loss)
+        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
+        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
+        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
+        tf.summary.scalar('model/Dz_real', tf.reduce_mean(Dz_real))
+        tf.summary.scalar('model/Dz_fake', tf.reduce_mean(Dz_fake))
+        tf.summary.scalar('model/lr', self.lr)
+        tf.summary.image('image/A', image_a[0:1])
+        tf.summary.image('image/B', image_b[0:1])
+        tf.summary.image('image/A-B', image_ab[0:1])
+        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
+        self.summary_op = tf.summary.merge_all()
+
+    def train(self, sess, summary_writer, data_A, data_B):
+        logger.info('Start training.')
+        logger.info('  {} images from A'.format(len(data_A)))
+        logger.info('  {} images from B'.format(len(data_B)))
+
+        data_size = min(len(data_A), len(data_B))
+        num_batch = data_size // self._batch_size
+        epoch_length = num_batch * self._batch_size
+
+        num_initial_iter = 8
+        num_decay_iter = 2
+        lr = lr_initial = 0.0002
+        lr_decay = lr_initial / num_decay_iter
+
+        initial_step = sess.run(self.global_step)
+        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
+        t = trange(initial_step, num_global_step,
+                   total=num_global_step, initial=initial_step)
+
+        for step in t:
+            # TODO: resume training with global_step
+            epoch = step // epoch_length
+            batch_idx = step % epoch_length
+
+            if epoch > num_initial_iter:
+                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)
+
+            # reshuffle the paired data at the start of each epoch
+            if batch_idx == 0:
+                data = list(zip(data_A, data_B))
+                random.shuffle(data)
+                data_A, data_B = zip(*data)
+
+            image_a = np.stack(data_A[batch_idx * self._batch_size:(batch_idx + 1) * self._batch_size])
+            image_b = np.stack(data_B[batch_idx * self._batch_size:(batch_idx + 1) * self._batch_size])
+            sample_z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
+
+            fetches = [self.loss,
+                       self.optimizer_D, self.optimizer_Dz,
+                       self.optimizer_G, self.optimizer_E]
+            if step % self._log_step == 0:
+                fetches += [self.summary_op]
+
+            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
+                                                   self.image_b: image_b,
+                                                   self.is_train: True,
+                                                   self.lr: lr,
+                                                   self.z: sample_z})
+
+            if step % self._log_step == 0:
+                summary_writer.add_summary(fetched[-1], step)
+                summary_writer.flush()
+                t.set_description('Loss({:.3f})'.format(fetched[0]))
+
+                # save an intermediate sample with a freshly drawn latent code
+                z = np.random.uniform(-1, 1, size=(self._batch_size, self._latent_dim))
+                image_ab = sess.run(self.image_ab,
+                                    feed_dict={self.image_a: image_a,
+                                               self.image_b: image_b,
+                                               self.lr: lr,
+                                               self.z: z,
+                                               self.is_train: True})
+                makedirs('results')
+                imsave('results/r_{}.jpg'.format(step), image_ab[0])
+
+    def test(self, sess, data_A, data_B, base_dir):
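+        # first pass: decode each test pair with 23 randomly sampled latent codes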
+        step = 0
+        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
+            step += 1
+            image_a = np.expand_dims(dataA, axis=0)
+            image_b = np.expand_dims(dataB, axis=0)
+            images = []
+            images.append(image_a)
+            images.append(image_b)
+
+            for i in range(23):
+                z = np.random.uniform(-1, 1, size=(1, self._latent_dim))
+                image_ab = sess.run(self.image_ab,
+                                    feed_dict={self.image_a: image_a,
+                                               self.z: z,
+                                               self.is_train: True})
+                images.append(image_ab)
+
+            image_rows = []
+            for i in range(5):
+                image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
+            images = np.concatenate(image_rows, axis=1)
+            images = np.squeeze(images, axis=0)
+            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)
+
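+        # second pass: sweep the first latent dimension linearly from -1 toward 1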
+        step = 0
+        for (dataA, dataB) in tqdm(zip(data_A, data_B)):
+            step += 1
+            image_a = np.expand_dims(dataA, axis=0)
+            image_b = np.expand_dims(dataB, axis=0)
+            images = []
+            images.append(image_a)
+            images.append(image_b)
+
+            for i in range(23):
+                z = np.zeros((1, self._latent_dim))
+                z[0][0] = (i / 23.0 - 0.5) * 2.0
+                image_ab = sess.run(self.image_ab,
+                                    feed_dict={self.image_a: image_a,
+                                               self.z: z,
+                                               self.is_train: True})
+                images.append(image_ab)
+
+            image_rows = []
+            for i in range(5):
+                image_rows.append(np.concatenate(images[i*5:(i+1)*5], axis=2))
+            images = np.concatenate(image_rows, axis=1)
+            images = np.squeeze(images, axis=0)
+            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)

+ 115 - 0
ops.py

@@ -0,0 +1,115 @@
+import tensorflow as tf
+
+
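+# normalization helper: instance norm, batch norm, or identity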
+def _norm(input, is_train, reuse=True, norm=None):
+    assert norm in ['instance', 'batch', None]
+    if norm == 'instance':
+        with tf.variable_scope('instance_norm', reuse=reuse):
+            eps = 1e-5
+            # normalize each channel by its spatial mean and variance
+            mean, var = tf.nn.moments(input, [1, 2], keep_dims=True)
+            normalized = (input - mean) / tf.sqrt(var + eps)
+            out = normalized
+            # Optional learnable scale and shift (affine parameters):
+            #c = input.get_shape()[-1]
+            #shift = tf.get_variable('shift', shape=[c],
+            #                        initializer=tf.zeros_initializer())
+            #scale = tf.get_variable('scale', shape=[c],
+            #                        initializer=tf.random_normal_initializer(1.0, 0.02))
+            #out = scale * normalized + shift
+    elif norm == 'batch':
+        with tf.variable_scope('batch_norm', reuse=reuse):
+            out = tf.contrib.layers.batch_norm(input,
+                                               decay=0.99, center=True,
+                                               scale=True, is_training=is_train,
+                                               updates_collections=None)
+    else:
+        out = input
+
+    return out
+
+def _activation(input, activation=None):
+    assert activation in ['relu', 'leaky', 'tanh', 'sigmoid', None]
+    if activation == 'relu':
+        return tf.nn.relu(input)
+    elif activation == 'leaky':
+        return tf.contrib.keras.layers.LeakyReLU(0.2)(input)
+    elif activation == 'tanh':
+        return tf.tanh(input)
+    elif activation == 'sigmoid':
+        return tf.sigmoid(input)
+    else:
+        return input
+
+def conv2d(input, num_filters, filter_size, stride, reuse=False,
+           pad='SAME', dtype=tf.float32, bias=False):
+    stride_shape = [1, stride, stride, 1]
+    filter_shape = [filter_size, filter_size, input.get_shape()[3], num_filters]
+
+    w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
+    if pad == 'REFLECT':
+        p = (filter_size - 1) // 2
+        x = tf.pad(input, [[0,0],[p,p],[p,p],[0,0]], 'REFLECT')
+        conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID')
+    else:
+        assert pad in ['SAME', 'VALID']
+        conv = tf.nn.conv2d(input, w, stride_shape, padding=pad)
+
+    if bias:
+        b = tf.get_variable('b', [1,1,1,num_filters], initializer=tf.constant_initializer(0.0))
+        conv = conv + b
+    return conv
+
+def conv2d_transpose(input, num_filters, filter_size, stride, reuse,
+                     pad='SAME', dtype=tf.float32):
+    assert pad == 'SAME'
+    n, h, w, c = input.get_shape().as_list()
+    stride_shape = [1, stride, stride, 1]
+    filter_shape = [filter_size, filter_size, num_filters, c]
+    output_shape = [n, h * stride, w * stride, num_filters]
+
+    weights = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
+    deconv = tf.nn.conv2d_transpose(input, weights, output_shape, stride_shape, pad)
+    return deconv
+
+def mlp(input, out_dim, name, is_train, reuse, norm=None, activation=None,
+        dtype=tf.float32, bias=True):
+    with tf.variable_scope(name, reuse=reuse):
+        _, n = input.get_shape()
+        w = tf.get_variable('w', [n, out_dim], dtype, tf.random_normal_initializer(0.0, 0.02))
+        out = tf.matmul(input, w)
+        if bias:
+            b = tf.get_variable('b', [out_dim], initializer=tf.constant_initializer(0.0))
+            out = out + b
+        out = _norm(out, is_train, reuse, norm)
+        out = _activation(out, activation)
+        return out
+
+def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
+          activation, pad='SAME', bias=False):
+    with tf.variable_scope(name, reuse=reuse):
+        out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias=bias)
+        out = _norm(out, is_train, reuse, norm)
+        out = _activation(out, activation)
+        return out
+
+def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT'):
+    with tf.variable_scope(name, reuse=reuse):
+        with tf.variable_scope('res1', reuse=reuse):
+            out = conv2d(input, num_filters, 3, 1, reuse, pad)
+            out = _norm(out, is_train, reuse, norm)
+            out = tf.nn.relu(out)
+
+        with tf.variable_scope('res2', reuse=reuse):
+            out = conv2d(out, num_filters, 3, 1, reuse, pad)
+            out = _norm(out, is_train, reuse, norm)
+
+        return tf.nn.relu(input + out)
+
+def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
+                 norm, activation, dropout):
+    with tf.variable_scope(name, reuse=reuse):
+        out = conv2d_transpose(input, num_filters, k_size, stride, reuse)
+        out = _norm(out, is_train, reuse, norm)
+        out = tf.nn.dropout(out, dropout)
+        out = _activation(out, activation)
+        return out

+ 12 - 0
utils.py

@@ -0,0 +1,12 @@
+import logging
+import os
+
+
+# start logging (basicConfig attaches a handler so INFO messages actually print)
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger('bicycle-gan')
+logger.setLevel(logging.INFO)
+logger.info("Start BicycleGAN")
+
+def makedirs(path):
+    if not os.path.exists(path):
+        os.makedirs(path)