
initial commit

karan hora 5 years ago
commit 9d0cb59da7

+ 7 - 0
.gitignore

@@ -0,0 +1,7 @@
+__pycache__
+.idea/*
+logs/*
+checkpoint/*
+datasets/*
+test/*
+sample/*

+ 27 - 0
README.md

@@ -0,0 +1,27 @@
+# CycleGAN
+
+A TensorFlow implementation of CycleGAN, as proposed by [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/) in
+[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf),
+applied here to highlighting tumors in brain MRI scans.
+
+
+## Prerequisites
+- tensorflow r1.1
+- numpy 1.11.0
+- scipy 0.17.0
+- pillow 3.3.0
+
+
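+### Download a dataset (optional)
+
+The med-image data used in this project is prepared separately with
+`prepareimgs.m`. To try one of the standard CycleGAN benchmarks instead, the
+bundled script downloads and unpacks it into `./datasets`, for example:
+```bash
+bash download_dataset.sh horse2zebra
+```
+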
+### Train
+```bash
+python main.py --dataset_dir=med-image
+```
+
+### Test
+```bash
+python main.py --dataset_dir=med-image --phase=test --which_direction=AtoB
+```
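+
+Generated images, along with an `AtoB_index.html` (or `BtoA_index.html`) page
+comparing inputs and outputs, are written to `./test` by default.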
+
+## References
+- CycleGAN implementation, https://github.com/XHUJOY/CycleGAN-tensorflow
+- Torch CycleGAN, https://github.com/junyanz/CycleGAN

+ 28 - 0
dice_score.m

@@ -0,0 +1,28 @@
+imgs = load('./brainTumorDataPublic_1-766/500.mat');
+binaryMask = imgs.cjdata.tumorMask;
+iptsetpref('ImshowBorder','tight');
+figure(2);
+imshow(binaryMask, []);
+
+I = imread('A_2_A2B.jpg');
+
+figure(3);
+imshow(I, []);
+
+% Convert the generated image into a binary mask.
+Igray = rgb2gray(I);
+BW = imbinarize(Igray);
+
+figure(4);
+imshow(BW,[  ]);
+% Display the original image next to the binary version.
+figure(5);
+imshowpair(I, BW, 'montage');
+
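+% Overlap between the ground-truth tumor mask and the binarized generated
+% image. Note that sum(A&B)/sum(A|B) is the Jaccard index (IoU); the Dice
+% coefficient proper would be 2*sum(A&B) / (sum(A) + sum(B)).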
+andImage = binaryMask & BW;
+orImage  = binaryMask | BW;
+diceCoeff = sum(andImage(:)) / sum(orImage(:));
+disp(diceCoeff);
+

+ 51 - 0
dice_score_multiple.m

@@ -0,0 +1,51 @@
+% Compute the mean overlap (IoU) between real and generated tumor highlights
+% across 124 image pairs from two folders.
+allScores = [];
+for k = 1:124
+
+    % Read the real (ground-truth) highlight and the generated one for pair k;
+    % skip the pair if either file is missing.
+
+    realHLfileName = strcat('./real_highlights/', 'B_', num2str(k+1), '_realB', '.jpg');
+    if exist(realHLfileName, 'file')
+        realImage = imread(realHLfileName);
+    else
+        fprintf('File %s does not exist.\n', realHLfileName);
+        continue;
+    end
+
+    genHLfileName = strcat('./generated_highlights/', 'A_', num2str(k), '_A2B', '.jpg');
+    if exist(genHLfileName, 'file')
+        genImage = imread(genHLfileName);
+    else
+        fprintf('File %s does not exist.\n', genHLfileName);
+        continue;
+    end
+
+    % Binarize both masks before comparing.
+    realIgray = rgb2gray(realImage);
+    genIgray = rgb2gray(genImage);
+
+    realBW = imbinarize(realIgray);
+    genBW = imbinarize(genIgray);
+    
+    andImage = realBW & genBW;
+    orImage  = realBW | genBW;
+    diceCoeff = sum(andImage(:)) / sum(orImage(:));
+    disp(diceCoeff);
+    
+    allScores = [allScores, diceCoeff];
+
+end
+
+disp(allScores);
+meanScore = mean(allScores);
+disp("-------");
+disp(meanScore);

+ 15 - 0
download_dataset.sh

@@ -0,0 +1,15 @@
+#!/bin/bash
+mkdir -p datasets
+FILE=$1
+
+if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
+    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
+    exit 1
+fi
+
+URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ZIP_FILE=./datasets/$FILE.zip
+TARGET_DIR=./datasets/$FILE/
+wget -N "$URL" -O "$ZIP_FILE"
+mkdir -p "$TARGET_DIR"
+unzip "$ZIP_FILE" -d ./datasets/
+rm "$ZIP_FILE"

BIN
imgs/AtoB_n02381460_4530.jpg


BIN
imgs/AtoB_n02381460_4660.jpg


BIN
imgs/AtoB_n02381460_510.jpg


BIN
imgs/AtoB_n02381460_8980.jpg


BIN
imgs/BtoA_n02391049_1760.jpg


BIN
imgs/BtoA_n02391049_3070.jpg


BIN
imgs/BtoA_n02391049_5100.jpg


BIN
imgs/BtoA_n02391049_7150.jpg


BIN
imgs/n02381460_4530.jpg


BIN
imgs/n02381460_4660.jpg


BIN
imgs/n02381460_510.jpg


BIN
imgs/n02381460_8980.jpg


BIN
imgs/n02391049_1760.jpg


BIN
imgs/n02391049_3070.jpg


BIN
imgs/n02391049_5100.jpg


BIN
imgs/n02391049_7150.jpg


BIN
imgs/teaser.jpg


+ 53 - 0
main.py

@@ -0,0 +1,53 @@
+import argparse
+import os
+import tensorflow as tf
+tf.set_random_seed(19)
+from model import cyclegan
+
+parser = argparse.ArgumentParser(description='CycleGAN for unpaired image-to-image translation')
+parser.add_argument('--dataset_dir', dest='dataset_dir', default='horse2zebra', help='path of the dataset')
+parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')
+parser.add_argument('--epoch_step', dest='epoch_step', type=int, default=100, help='# of epoch to decay lr')
+parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
+parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train')
+parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
+parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
+parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
+parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
+parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
+parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
+parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
+parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
+parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
+parser.add_argument('--phase', dest='phase', default='train', help='train, test')
+parser.add_argument('--save_freq', dest='save_freq', type=int, default=1000, help='save a model every save_freq iterations')
+parser.add_argument('--print_freq', dest='print_freq', type=int, default=100, help='print the debug information every print_freq iterations')
+parser.add_argument('--continue_train', dest='continue_train', type=int, default=0, help='continue training by loading the latest model, 1: true, 0: false')
+parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
+parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')
+parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')
+parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=10.0, help='weight on L1 term in objective')
+parser.add_argument('--use_resnet', dest='use_resnet', type=int, default=1, help='if 1, build the generator from residual blocks, else use the U-Net generator')
+parser.add_argument('--use_lsgan', dest='use_lsgan', type=int, default=1, help='if 1, use the least-squares GAN loss, else sigmoid cross-entropy')
+parser.add_argument('--max_size', dest='max_size', type=int, default=50, help='max size of image pool, 0 means do not use image pool')
+
+args = parser.parse_args()
+
+
+def main(_):
+    if not os.path.exists(args.checkpoint_dir):
+        os.makedirs(args.checkpoint_dir)
+    if not os.path.exists(args.sample_dir):
+        os.makedirs(args.sample_dir)
+    if not os.path.exists(args.test_dir):
+        os.makedirs(args.test_dir)
+
+    tfconfig = tf.ConfigProto(allow_soft_placement=True)
+    tfconfig.gpu_options.allow_growth = True
+    with tf.Session(config=tfconfig) as sess:
+        model = cyclegan(sess, args)
+        if args.phase == 'train':
+            model.train(args)
+        else:
+            model.test(args)
+
+if __name__ == '__main__':
+    tf.app.run()

+ 264 - 0
model.py

@@ -0,0 +1,264 @@
+from __future__ import division
+import os
+import time
+from glob import glob
+import tensorflow as tf
+import numpy as np
+from collections import namedtuple
+
+from module import *
+from utils import *
+
+
+class cyclegan(object):
+    def __init__(self, sess, args):
+        self.sess = sess
+        self.batch_size = args.batch_size
+        self.image_size = args.fine_size
+        self.input_c_dim = args.input_nc
+        self.output_c_dim = args.output_nc
+        self.L1_lambda = args.L1_lambda
+        self.dataset_dir = args.dataset_dir
+
+        self.discriminator = discriminator
+        if args.use_resnet:
+            self.generator = generator_resnet
+        else:
+            self.generator = generator_unet
+        if args.use_lsgan:
+            self.criterionGAN = mae_criterion
+        else:
+            self.criterionGAN = sce_criterion
+
+        OPTIONS = namedtuple('OPTIONS', 'batch_size image_size \
+                              gf_dim df_dim output_c_dim is_training')
+        self.options = OPTIONS._make((args.batch_size, args.fine_size,
+                                      args.ngf, args.ndf, args.output_nc,
+                                      args.phase == 'train'))
+
+        self._build_model()
+        self.saver = tf.train.Saver()
+        self.pool = ImagePool(args.max_size)
+
+    def _build_model(self):
+        self.real_data = tf.placeholder(tf.float32,
+                                        [None, self.image_size, self.image_size,
+                                         self.input_c_dim + self.output_c_dim],
+                                        name='real_A_and_B_images')
+
+        self.real_A = self.real_data[:, :, :, :self.input_c_dim]
+        self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
+
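+        # Forward cycle A -> B -> A and backward cycle B -> A -> B; the second
+        # call into each generator scope passes reuse=True, so both cycles
+        # share one set of weights per direction.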
+        self.fake_B = self.generator(self.real_A, self.options, False, name="generatorA2B")
+        self.fake_A_ = self.generator(self.fake_B, self.options, False, name="generatorB2A")
+        self.fake_A = self.generator(self.real_B, self.options, True, name="generatorB2A")
+        self.fake_B_ = self.generator(self.fake_A, self.options, True, name="generatorA2B")
+
+        self.DB_fake = self.discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB")
+        self.DA_fake = self.discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA")
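+        # Generator losses: an adversarial term against each discriminator plus
+        # L1 cycle-consistency terms on both reconstructions, weighted by L1_lambda.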
+        self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
+            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
+            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
+        self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
+            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
+            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
+        self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
+            + self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
+            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
+            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
+
+        self.fake_A_sample = tf.placeholder(tf.float32,
+                                            [None, self.image_size, self.image_size,
+                                             self.input_c_dim], name='fake_A_sample')
+        self.fake_B_sample = tf.placeholder(tf.float32,
+                                            [None, self.image_size, self.image_size,
+                                             self.output_c_dim], name='fake_B_sample')
+        self.DB_real = self.discriminator(self.real_B, self.options, reuse=True, name="discriminatorB")
+        self.DA_real = self.discriminator(self.real_A, self.options, reuse=True, name="discriminatorA")
+        self.DB_fake_sample = self.discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB")
+        self.DA_fake_sample = self.discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA")
+
+        self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
+        self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
+        self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2
+        self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
+        self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
+        self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
+        self.d_loss = self.da_loss + self.db_loss
+
+        self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
+        self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
+        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
+        self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])
+        self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
+        self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
+        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
+        self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
+        self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
+        self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
+        self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
+        self.d_sum = tf.summary.merge(
+            [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
+             self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
+             self.d_loss_sum]
+        )
+
+        self.test_A = tf.placeholder(tf.float32,
+                                     [None, self.image_size, self.image_size,
+                                      self.input_c_dim], name='test_A')
+        self.test_B = tf.placeholder(tf.float32,
+                                     [None, self.image_size, self.image_size,
+                                      self.output_c_dim], name='test_B')
+        self.testB = self.generator(self.test_A, self.options, True, name="generatorA2B")
+        self.testA = self.generator(self.test_B, self.options, True, name="generatorB2A")
+
+        t_vars = tf.trainable_variables()
+        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
+        self.g_vars = [var for var in t_vars if 'generator' in var.name]
+        for var in t_vars: print(var.name)
+
+    def train(self, args):
+        """Train cyclegan"""
+        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
+        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
+            .minimize(self.d_loss, var_list=self.d_vars)
+        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
+            .minimize(self.g_loss, var_list=self.g_vars)
+
+        init_op = tf.global_variables_initializer()
+        self.sess.run(init_op)
+        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)
+
+        counter = 1
+        start_time = time.time()
+
+        if args.continue_train:
+            if self.load(args.checkpoint_dir):
+                print(" [*] Load SUCCESS")
+            else:
+                print(" [!] Load failed...")
+
+        for epoch in range(args.epoch):
+            dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainA'))
+            dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainB'))
+            np.random.shuffle(dataA)
+            np.random.shuffle(dataB)
+            batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
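+            # Hold the learning rate constant for the first epoch_step epochs,
+            # then decay it linearly toward zero over the remaining epochs.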
+            lr = args.lr if epoch < args.epoch_step else args.lr*(args.epoch-epoch)/(args.epoch-args.epoch_step)
+
+            for idx in range(0, batch_idxs):
+                batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
+                                       dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
+                batch_images = [load_train_data(batch_file, args.load_size, args.fine_size) for batch_file in batch_files]
+                batch_images = np.array(batch_images).astype(np.float32)
+
+                # Update G network and record fake outputs
+                fake_A, fake_B, _, summary_str = self.sess.run(
+                    [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
+                    feed_dict={self.real_data: batch_images, self.lr: lr})
+                self.writer.add_summary(summary_str, counter)
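+                # Swap in samples from a pool of previously generated images so
+                # the discriminator also trains on older generator outputs.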
+                [fake_A, fake_B] = self.pool([fake_A, fake_B])
+
+                # Update D network
+                _, summary_str = self.sess.run(
+                    [self.d_optim, self.d_sum],
+                    feed_dict={self.real_data: batch_images,
+                               self.fake_A_sample: fake_A,
+                               self.fake_B_sample: fake_B,
+                               self.lr: lr})
+                self.writer.add_summary(summary_str, counter)
+
+                counter += 1
+                print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
+                    epoch, idx, batch_idxs, time.time() - start_time)))
+
+                if np.mod(counter, args.print_freq) == 1:
+                    self.sample_model(args.sample_dir, epoch, idx)
+
+                if np.mod(counter, args.save_freq) == 2:
+                    self.save(args.checkpoint_dir, counter)
+
+    def save(self, checkpoint_dir, step):
+        model_name = "cyclegan.model"
+        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
+        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
+
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.saver.save(self.sess,
+                        os.path.join(checkpoint_dir, model_name),
+                        global_step=step)
+
+    def load(self, checkpoint_dir):
+        print(" [*] Reading checkpoint...")
+
+        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
+        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
+
+        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+        if ckpt and ckpt.model_checkpoint_path:
+            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
+            return True
+        else:
+            return False
+
+    def sample_model(self, sample_dir, epoch, idx):
+        dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
+        dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
+        np.random.shuffle(dataA)
+        np.random.shuffle(dataB)
+        batch_files = list(zip(dataA[:self.batch_size], dataB[:self.batch_size]))
+        sample_images = [load_train_data(batch_file, is_testing=True) for batch_file in batch_files]
+        sample_images = np.array(sample_images).astype(np.float32)
+
+        fake_A, fake_B = self.sess.run(
+            [self.fake_A, self.fake_B],
+            feed_dict={self.real_data: sample_images}
+        )
+        save_images(fake_A, [self.batch_size, 1],
+                    './{}/A_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
+        save_images(fake_B, [self.batch_size, 1],
+                    './{}/B_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
+
+    def test(self, args):
+        """Test cyclegan"""
+        init_op = tf.global_variables_initializer()
+        self.sess.run(init_op)
+        if args.which_direction == 'AtoB':
+            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
+        elif args.which_direction == 'BtoA':
+            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
+        else:
+            raise Exception('--which_direction must be AtoB or BtoA')
+
+        if self.load(args.checkpoint_dir):
+            print(" [*] Load SUCCESS")
+        else:
+            print(" [!] Load failed...")
+
+        # write html for visual comparison
+        index_path = os.path.join(args.test_dir, '{0}_index.html'.format(args.which_direction))
+        index = open(index_path, "w")
+        index.write("<html><body><table><tr>")
+        index.write("<th>name</th><th>input</th><th>output</th></tr>")
+
+        out_var, in_var = (self.testB, self.test_A) if args.which_direction == 'AtoB' else (
+            self.testA, self.test_B)
+
+        for sample_file in sample_files:
+            print('Processing image: ' + sample_file)
+            sample_image = [load_test_data(sample_file, args.fine_size)]
+            sample_image = np.array(sample_image).astype(np.float32)
+            image_path = os.path.join(args.test_dir,
+                                      '{0}_{1}'.format(args.which_direction, os.path.basename(sample_file)))
+            fake_img = self.sess.run(out_var, feed_dict={in_var: sample_image})
+            save_images(fake_img, [1, 1], image_path)
+            index.write("<tr><td>%s</td>" % os.path.basename(image_path))
+            index.write("<td><img src='%s'></td>" % (sample_file if os.path.isabs(sample_file) else (
+                '..' + os.path.sep + sample_file)))
+            index.write("<td><img src='%s'></td>" % (image_path if os.path.isabs(image_path) else (
+                '..' + os.path.sep + image_path)))
+            index.write("</tr>")
+        index.write("</table></body></html>")
+        index.close()

+ 148 - 0
module.py

@@ -0,0 +1,148 @@
+from __future__ import division
+import tensorflow as tf
+from ops import *
+from utils import *
+
+
+def discriminator(image, options, reuse=False, name="discriminator"):
+
+    with tf.variable_scope(name):
+        # image is 256 x 256 x input_c_dim
+        if reuse:
+            tf.get_variable_scope().reuse_variables()
+        else:
+            assert tf.get_variable_scope().reuse is False
+
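+        # PatchGAN discriminator: three stride-2 convs downsample the 256x256
+        # input to 32x32; the final stride-1 convs emit a 32x32x1 grid of
+        # real/fake logits, one prediction per overlapping image patch.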
+        h0 = lrelu(conv2d(image, options.df_dim, name='d_h0_conv'))
+        # h0 is (128 x 128 x self.df_dim)
+        h1 = lrelu(instance_norm(conv2d(h0, options.df_dim*2, name='d_h1_conv'), 'd_bn1'))
+        # h1 is (64 x 64 x self.df_dim*2)
+        h2 = lrelu(instance_norm(conv2d(h1, options.df_dim*4, name='d_h2_conv'), 'd_bn2'))
+        # h2 is (32 x 32 x self.df_dim*4)
+        h3 = lrelu(instance_norm(conv2d(h2, options.df_dim*8, s=1, name='d_h3_conv'), 'd_bn3'))
+        # h3 is (32 x 32 x self.df_dim*8)
+        h4 = conv2d(h3, 1, s=1, name='d_h3_pred')
+        # h4 is (32 x 32 x 1)
+        return h4
+
+
+def generator_unet(image, options, reuse=False, name="generator"):
+
+    dropout_rate = 0.5 if options.is_training else 1.0
+    with tf.variable_scope(name):
+        # image is 256 x 256 x input_c_dim
+        if reuse:
+            tf.get_variable_scope().reuse_variables()
+        else:
+            assert tf.get_variable_scope().reuse is False
+
+        # image is (256 x 256 x input_c_dim)
+        e1 = instance_norm(conv2d(image, options.gf_dim, name='g_e1_conv'))
+        # e1 is (128 x 128 x self.gf_dim)
+        e2 = instance_norm(conv2d(lrelu(e1), options.gf_dim*2, name='g_e2_conv'), 'g_bn_e2')
+        # e2 is (64 x 64 x self.gf_dim*2)
+        e3 = instance_norm(conv2d(lrelu(e2), options.gf_dim*4, name='g_e3_conv'), 'g_bn_e3')
+        # e3 is (32 x 32 x self.gf_dim*4)
+        e4 = instance_norm(conv2d(lrelu(e3), options.gf_dim*8, name='g_e4_conv'), 'g_bn_e4')
+        # e4 is (16 x 16 x self.gf_dim*8)
+        e5 = instance_norm(conv2d(lrelu(e4), options.gf_dim*8, name='g_e5_conv'), 'g_bn_e5')
+        # e5 is (8 x 8 x self.gf_dim*8)
+        e6 = instance_norm(conv2d(lrelu(e5), options.gf_dim*8, name='g_e6_conv'), 'g_bn_e6')
+        # e6 is (4 x 4 x self.gf_dim*8)
+        e7 = instance_norm(conv2d(lrelu(e6), options.gf_dim*8, name='g_e7_conv'), 'g_bn_e7')
+        # e7 is (2 x 2 x self.gf_dim*8)
+        e8 = instance_norm(conv2d(lrelu(e7), options.gf_dim*8, name='g_e8_conv'), 'g_bn_e8')
+        # e8 is (1 x 1 x self.gf_dim*8)
+
+        d1 = deconv2d(tf.nn.relu(e8), options.gf_dim*8, name='g_d1')
+        d1 = tf.nn.dropout(d1, dropout_rate)
+        d1 = tf.concat([instance_norm(d1, 'g_bn_d1'), e7], 3)
+        # d1 is (2 x 2 x self.gf_dim*8*2)
+
+        d2 = deconv2d(tf.nn.relu(d1), options.gf_dim*8, name='g_d2')
+        d2 = tf.nn.dropout(d2, dropout_rate)
+        d2 = tf.concat([instance_norm(d2, 'g_bn_d2'), e6], 3)
+        # d2 is (4 x 4 x self.gf_dim*8*2)
+
+        d3 = deconv2d(tf.nn.relu(d2), options.gf_dim*8, name='g_d3')
+        d3 = tf.nn.dropout(d3, dropout_rate)
+        d3 = tf.concat([instance_norm(d3, 'g_bn_d3'), e5], 3)
+        # d3 is (8 x 8 x self.gf_dim*8*2)
+
+        d4 = deconv2d(tf.nn.relu(d3), options.gf_dim*8, name='g_d4')
+        d4 = tf.concat([instance_norm(d4, 'g_bn_d4'), e4], 3)
+        # d4 is (16 x 16 x self.gf_dim*8*2)
+
+        d5 = deconv2d(tf.nn.relu(d4), options.gf_dim*4, name='g_d5')
+        d5 = tf.concat([instance_norm(d5, 'g_bn_d5'), e3], 3)
+        # d5 is (32 x 32 x self.gf_dim*4*2)
+
+        d6 = deconv2d(tf.nn.relu(d5), options.gf_dim*2, name='g_d6')
+        d6 = tf.concat([instance_norm(d6, 'g_bn_d6'), e2], 3)
+        # d6 is (64 x 64 x self.gf_dim*2*2)
+
+        d7 = deconv2d(tf.nn.relu(d6), options.gf_dim, name='g_d7')
+        d7 = tf.concat([instance_norm(d7, 'g_bn_d7'), e1], 3)
+        # d7 is (128 x 128 x self.gf_dim*1*2)
+
+        d8 = deconv2d(tf.nn.relu(d7), options.output_c_dim, name='g_d8')
+        # d8 is (256 x 256 x output_c_dim)
+
+        return tf.nn.tanh(d8)
+
+
+def generator_resnet(image, options, reuse=False, name="generator"):
+
+    with tf.variable_scope(name):
+        # image is 256 x 256 x input_c_dim
+        if reuse:
+            tf.get_variable_scope().reuse_variables()
+        else:
+            assert tf.get_variable_scope().reuse is False
+
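+        # Residual block: reflection-pad, then two 3x3 VALID convs with
+        # instance norm (padding keeps spatial size), plus a skip connection.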
+        def residual_block(x, dim, ks=3, s=1, name='res'):
+            p = int((ks - 1) / 2)
+            y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
+            y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c1'), name+'_bn1')
+            y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
+            y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c2'), name+'_bn2')
+            return y + x
+
+        # Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
+        # The network with 9 blocks consists of: c7s1-32, d64, d128, R128, R128, R128,
+        # R128, R128, R128, R128, R128, R128, u64, u32, c7s1-3
+        c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
+        c1 = tf.nn.relu(instance_norm(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
+        c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim*2, 3, 2, name='g_e2_c'), 'g_e2_bn'))
+        c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim*4, 3, 2, name='g_e3_c'), 'g_e3_bn'))
+        # define G network with 9 resnet blocks
+        r1 = residual_block(c3, options.gf_dim*4, name='g_r1')
+        r2 = residual_block(r1, options.gf_dim*4, name='g_r2')
+        r3 = residual_block(r2, options.gf_dim*4, name='g_r3')
+        r4 = residual_block(r3, options.gf_dim*4, name='g_r4')
+        r5 = residual_block(r4, options.gf_dim*4, name='g_r5')
+        r6 = residual_block(r5, options.gf_dim*4, name='g_r6')
+        r7 = residual_block(r6, options.gf_dim*4, name='g_r7')
+        r8 = residual_block(r7, options.gf_dim*4, name='g_r8')
+        r9 = residual_block(r8, options.gf_dim*4, name='g_r9')
+
+        d1 = deconv2d(r9, options.gf_dim*2, 3, 2, name='g_d1_dc')
+        d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn'))
+        d2 = deconv2d(d1, options.gf_dim, 3, 2, name='g_d2_dc')
+        d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn'))
+        d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
+        pred = tf.nn.tanh(conv2d(d2, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c'))
+
+        return pred
+
+
+def abs_criterion(in_, target):
+    return tf.reduce_mean(tf.abs(in_ - target))
+
+
+def mae_criterion(in_, target):
+    return tf.reduce_mean((in_-target)**2)
+
+
+def sce_criterion(logits, labels):
+    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

+ 48 - 0
ops.py

@@ -0,0 +1,48 @@
+import math
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.python.framework import ops
+
+from utils import *
+
+def batch_norm(x, name="batch_norm"):
+    return tf.contrib.layers.batch_norm(x, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope=name)
+
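+# Instance normalization: normalize each sample's feature maps over the
+# spatial axes (H, W), then apply a learned per-channel scale and offset:
+#   y = scale * (x - mean) / sqrt(var + eps) + offset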
+def instance_norm(input, name="instance_norm"):
+    with tf.variable_scope(name):
+        depth = input.get_shape()[3]
+        scale = tf.get_variable("scale", [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
+        offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0))
+        mean, variance = tf.nn.moments(input, axes=[1,2], keep_dims=True)
+        epsilon = 1e-5
+        inv = tf.rsqrt(variance + epsilon)
+        normalized = (input-mean)*inv
+        return scale*normalized + offset
+
+def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"):
+    with tf.variable_scope(name):
+        return slim.conv2d(input_, output_dim, ks, s, padding=padding, activation_fn=None,
+                            weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
+                            biases_initializer=None)
+
+def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"):
+    with tf.variable_scope(name):
+        return slim.conv2d_transpose(input_, output_dim, ks, s, padding='SAME', activation_fn=None,
+                                    weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
+                                    biases_initializer=None)
+
+def lrelu(x, leak=0.2, name="lrelu"):
+    return tf.maximum(x, leak*x)
+
+def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
+
+    with tf.variable_scope(scope or "Linear"):
+        matrix = tf.get_variable("Matrix", [input_.get_shape()[-1], output_size], tf.float32,
+                                 tf.random_normal_initializer(stddev=stddev))
+        bias = tf.get_variable("bias", [output_size],
+            initializer=tf.constant_initializer(bias_start))
+        if with_w:
+            return tf.matmul(input_, matrix) + bias, matrix, bias
+        else:
+            return tf.matmul(input_, matrix) + bias

+ 48 - 0
prepareimgs.m

@@ -0,0 +1,48 @@
+
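+% Convert the .mat brain-tumor scans into JPEG pairs for CycleGAN training:
+% raw MRI slices are written to domain A and tumor-highlighted versions
+% (slice + scaled tumor mask) to domain B. All three tumor types (labels
+% 1-3) share the same A and B folders.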
+DirectoryPathTA ='./datasets/train/med-image/A';
+DirectoryPathTB ='./datasets/train/med-image/A';
+DirectoryPathTC ='./datasets/train/med-image/A';
+DirectoryPathTAhighlight ='./datasets/train/med-image/B';
+DirectoryPathTBhighlight ='./datasets/train/med-image/B';
+DirectoryPathTChighlight ='./datasets/train/med-image/B';
+files = dir('./brainTumorDataPublic_1-766/*.mat');
+
+i = 1;
+for file = files'
+    imgs = load(fullfile('./brainTumorDataPublic_1-766/',file.name));
+    imgA = imgs.cjdata.image;
+    imgA = double(imgA);
+    label = imgs.cjdata.label;
+    pid = imgs.cjdata.PID;
+    iptsetpref('ImshowBorder','tight');
+    figure(1);
+    imshow(imgA,[0 2825]);
+    imgB = imgs.cjdata.tumorMask;
+    figure(2);
+    imshow(imgB);
+    imgBdoub = double(imgB)*3000;
+    img3 = imgA + imgBdoub;
+    iptsetpref('ImshowBorder','tight');
+    figure(3);
+    imshow(img3,[0 6000]);
+    if label == 1
+        % Save original image
+        whereToStore=fullfile(DirectoryPathTA,['A_TA_',num2str(i),'_',pid, '.jpg']);
+        saveas(figure(1), whereToStore);
+        whereToStoreLabel=fullfile(DirectoryPathTAhighlight,['B_TA_',num2str(i),'_',pid, '.jpg']);
+        saveas(figure(3), whereToStoreLabel);
+    elseif label == 2
+        whereToStore=fullfile(DirectoryPathTB,['A_TB_',num2str(i),'_',pid, '.jpg']);
+        saveas(figure(1), whereToStore);
+        whereToStoreLabel=fullfile(DirectoryPathTBhighlight,['B_TB_',num2str(i),'_',pid, '.jpg']);
+        saveas(figure(3), whereToStoreLabel);
+    elseif label == 3
+        whereToStore=fullfile(DirectoryPathTC,['A_TC_',num2str(i),'_',pid, '.jpg']);
+        saveas(figure(1), whereToStore);
+        whereToStoreLabel=fullfile(DirectoryPathTChighlight,['B_TC_',num2str(i),'_',pid, '.jpg']);
+        saveas(figure(3), whereToStoreLabel);
+    end
+    i = i + 1;
+
+end
+

+ 5 - 0
requirements.txt

@@ -0,0 +1,5 @@
+tensorflow-gpu
+numpy
+scipy
+pillow
+imageio

+ 125 - 0
utils.py

@@ -0,0 +1,125 @@
+"""
+Some codes from https://github.com/Newmu/dcgan_code
+"""
+from __future__ import division
+import math
+import pprint
+import scipy.misc
+import numpy as np
+import copy
+try:
+    _imread = scipy.misc.imread
+except AttributeError:
+    from imageio import imread as _imread
+
+pp = pprint.PrettyPrinter()
+
+get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
+
+# -----------------------------
+# new added functions for cyclegan
+class ImagePool(object):
+    def __init__(self, maxsize=50):
+        self.maxsize = maxsize
+        self.num_img = 0
+        self.images = []
+
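+    # Until the pool is full, cache and pass images through unchanged. Once
+    # full, with probability 0.5 return randomly chosen past fakes and store
+    # the new pair in their place; otherwise return the new pair as-is.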
+    def __call__(self, image):
+        if self.maxsize <= 0:
+            return image
+        if self.num_img < self.maxsize:
+            self.images.append(image)
+            self.num_img += 1
+            return image
+        if np.random.rand() > 0.5:
+            idx = int(np.random.rand()*self.maxsize)
+            tmp1 = copy.copy(self.images[idx])[0]
+            self.images[idx][0] = image[0]
+            idx = int(np.random.rand()*self.maxsize)
+            tmp2 = copy.copy(self.images[idx])[1]
+            self.images[idx][1] = image[1]
+            return [tmp1, tmp2]
+        else:
+            return image
+
+def load_test_data(image_path, fine_size=256):
+    img = imread(image_path)
+    img = scipy.misc.imresize(img, [fine_size, fine_size])
+    img = img/127.5 - 1
+    return img
+
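+# Load an (A, B) image pair; during training, resize both to load_size, take
+# the same random fine_size crop from each, and mirror both together half the
+# time; then scale pixel values to [-1, 1] and stack along the channel axis.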
+def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
+    img_A = imread(image_path[0])
+    img_B = imread(image_path[1])
+    if not is_testing:
+        img_A = scipy.misc.imresize(img_A, [load_size, load_size])
+        img_B = scipy.misc.imresize(img_B, [load_size, load_size])
+        h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
+        w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
+        img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]
+        img_B = img_B[h1:h1+fine_size, w1:w1+fine_size]
+
+        if np.random.random() > 0.5:
+            img_A = np.fliplr(img_A)
+            img_B = np.fliplr(img_B)
+    else:
+        img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
+        img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
+
+    img_A = img_A/127.5 - 1.
+    img_B = img_B/127.5 - 1.
+
+    img_AB = np.concatenate((img_A, img_B), axis=2)
+    # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim)
+    return img_AB
+
+# -----------------------------
+
+def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
+    return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)
+
+def save_images(images, size, image_path):
+    return imsave(inverse_transform(images), size, image_path)
+
+def imread(path, is_grayscale=False):
+    if is_grayscale:
+        return _imread(path, flatten=True).astype(np.float)
+    else:
+        return _imread(path, mode='RGB').astype(np.float)
+
+def merge_images(images, size):
+    return inverse_transform(images)
+
+def merge(images, size):
+    h, w = images.shape[1], images.shape[2]
+    img = np.zeros((h * size[0], w * size[1], 3))
+    for idx, image in enumerate(images):
+        i = idx % size[1]
+        j = idx // size[1]
+        img[j*h:j*h+h, i*w:i*w+w, :] = image
+
+    return img
+
+def imsave(images, size, path):
+    return scipy.misc.imsave(path, merge(images, size))
+
+def center_crop(x, crop_h, crop_w,
+                resize_h=64, resize_w=64):
+    if crop_w is None:
+        crop_w = crop_h
+    h, w = x.shape[:2]
+    j = int(round((h - crop_h)/2.))
+    i = int(round((w - crop_w)/2.))
+    return scipy.misc.imresize(
+        x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])
+
+def transform(image, npx=64, is_crop=True, resize_w=64):
+    # npx : # of pixels width/height of image
+    if is_crop:
+        cropped_image = center_crop(image, npx, resize_w=resize_w)
+    else:
+        cropped_image = image
+    return np.array(cropped_image)/127.5 - 1.
+
+def inverse_transform(images):
+    return (images+1.)/2.