Hao SUN преди 3 години
родител
ревизия
7d73cf04f0
променени са 4 файла, в които са добавени 20 реда и са изтрити 16 реда
  1. 2 2
      README.md
  2. 4 4
      main.py
  3. 2 1
      model.py
  4. 12 9
      utils.py

+ 2 - 2
README.md

@@ -12,7 +12,7 @@ As proposed by [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/) in
 
 ### Train
 ```bash
-python main.py --dataset_dir=med_image
+python main.py --dataset_dir=med-image
 ```
 
 ### Test
@@ -22,4 +22,4 @@ python main.py --dataset_dir=med-image --phase=test --which_direction=AtoB
 
 ## References
 - CycleGAN implementation, https://github.com/XHUJOY/CycleGAN-tensorflow
-- Torch CycleGAN, https://github.com/junyanz/CycleGAN
+- Torch CycleGAN, https://github.com/junyanz/CycleGAN

+ 4 - 4
main.py

@@ -5,10 +5,10 @@ tf.set_random_seed(19)
 from model import cyclegan
 
 parser = argparse.ArgumentParser(description='')
-parser.add_argument('--dataset_dir', dest='dataset_dir', default='horse2zebra', help='path of the dataset')
+parser.add_argument('--dataset_dir', dest='dataset_dir', default='med-image', help='path of the dataset')
 parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')
 parser.add_argument('--epoch_step', dest='epoch_step', type=int, default=100, help='# of epoch to decay lr')
-parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
+parser.add_argument('--batch_size', dest='batch_size', type=int, default=4, help='# images in batch')
 parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train')
 parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
 parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
@@ -20,8 +20,8 @@ parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial
 parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
 parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
 parser.add_argument('--phase', dest='phase', default='train', help='train, test')
-parser.add_argument('--save_freq', dest='save_freq', type=int, default=1000, help='save a model every save_freq iterations')
-parser.add_argument('--print_freq', dest='print_freq', type=int, default=100, help='print the debug information every print_freq iterations')
+parser.add_argument('--save_freq', dest='save_freq', type=int, default=3, help='save a model every save_freq iterations')
+parser.add_argument('--print_freq', dest='print_freq', type=int, default=2, help='print the debug information every print_freq iterations')
 parser.add_argument('--continue_train', dest='continue_train', type=bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
 parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
 parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')

+ 2 - 1
model.py

@@ -172,9 +172,10 @@ class cyclegan(object):
                 print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
                     epoch, idx, batch_idxs, time.time() - start_time)))
 
+                '''
                 if np.mod(counter, args.print_freq) == 1:
                     self.sample_model(args.sample_dir, epoch, idx)
-
+                '''
                 if np.mod(counter, args.save_freq) == 2:
                     self.save(args.checkpoint_dir, counter)
 

+ 12 - 9
utils.py

@@ -7,6 +7,9 @@ import pprint
 import scipy.misc
 import numpy as np
 import copy
+from skimage.transform import resize
+from skimage.io import imsave
+
 try:
     _imread = scipy.misc.imread
 except AttributeError:
@@ -44,7 +47,7 @@ class ImagePool(object):
 
 def load_test_data(image_path, fine_size=256):
     img = imread(image_path)
-    img = scipy.misc.imresize(img, [fine_size, fine_size])
+    img = resize(img, [fine_size, fine_size])
     img = img/127.5 - 1
     return img
 
@@ -52,8 +55,8 @@ def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
     img_A = imread(image_path[0])
     img_B = imread(image_path[1])
     if not is_testing:
-        img_A = scipy.misc.imresize(img_A, [load_size, load_size])
-        img_B = scipy.misc.imresize(img_B, [load_size, load_size])
+        img_A = resize(img_A, [load_size, load_size])
+        img_B = resize(img_B, [load_size, load_size])
         h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
         w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
         img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]
@@ -63,8 +66,8 @@ def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
             img_A = np.fliplr(img_A)
             img_B = np.fliplr(img_B)
     else:
-        img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
-        img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
+        img_A = resize(img_A, [fine_size, fine_size])
+        img_B = resize(img_B, [fine_size, fine_size])
 
     img_A = img_A/127.5 - 1.
     img_B = img_B/127.5 - 1.
@@ -79,13 +82,13 @@ def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale =
     return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)
 
 def save_images(images, size, image_path):
-    return imsave(inverse_transform(images), size, image_path)
+    return imsave(image_path, inverse_transform(images))
 
 def imread(path, is_grayscale = False):
     if (is_grayscale):
         return _imread(path, flatten=True).astype(np.float)
     else:
-        return _imread(path, mode='RGB').astype(np.float)
+        return _imread(path, pilmode='RGB').astype(np.float)
 
 def merge_images(images, size):
     return inverse_transform(images)
@@ -100,8 +103,8 @@ def merge(images, size):
 
     return img
 
-def imsave(images, size, path):
-    return scipy.misc.imsave(path, merge(images, size))
+#def imsave(images, size, path):
+#    return scipy.misc.imsave(path, merge(images, size))
 
 def center_crop(x, crop_h, crop_w,
                 resize_h=64, resize_w=64):