Jelajahi Sumber

update to python 3

SUN Hao 4 tahun lalu
induk
melakukan
a4a3be8130
2 mengubah file dengan 17 tambahan dan 8 penghapusan
  1. 14 6
      data_loader.py
  2. 3 2
      model.py

+ 14 - 6
data_loader.py

@@ -1,8 +1,13 @@
 import os
 from glob import glob
 
-from scipy.misc import imread, imresize
+#from scipy.misc import imread, imresize
+import imageio
+from imageio import imread
+import numpy
+from PIL import Image
 import numpy as np
+from skimage.transform import resize
 from tqdm import tqdm
 import h5py
 
@@ -19,11 +24,11 @@ def read_image(path):
     assert w in [256, 512, 1200], 'Image size mismatch ({}, {})'.format(h, w)
     assert h in [128, 256, 600], 'Image size mismatch ({}, {})'.format(h, w)
     if 'maps' in path:
-        image_a = image[:, w/2:, :].astype(np.float32) / 255.0
-        image_b = image[:, :w/2, :].astype(np.float32) / 255.0
+        image_a = image[:, int(w/2):, :].astype(np.float32) / 255.0
+        image_b = image[:, :int(w/2), :].astype(np.float32) / 255.0
     else:
-        image_a = image[:, :w/2, :].astype(np.float32) / 255.0
-        image_b = image[:, w/2:, :].astype(np.float32) / 255.0
+        image_a = image[:, :int(w/2), :].astype(np.float32) / 255.0
+        image_b = image[:, int(w/2):, :].astype(np.float32) / 255.0
 
     # range of pixel values = [-1.0, 1.0]
     image_a = image_a * 2.0 - 1.0
@@ -53,7 +58,10 @@ def store_h5py(base_dir, dir_name, images, image_size):
     for i in range(len(images)):
         grp = f.create_group(str(i))
         if images[i].shape[0] != image_size:
-            image = imresize(images[i], (image_size, image_size, 3))
+            #image = imresize(images[i], (image_size, image_size, 3))
+            #print(i)
+            #image = numpy.array(Image.fromarray(images[i]).resize((image_size, image_size, 3)))
+            image = resize(images[i], (image_size, image_size, 3))
             # range of pixel values = [-1.0, 1.0]
             image = image.astype(np.float32) / 255.0
             image = image * 2.0 - 1.0

+ 3 - 2
model.py

@@ -2,7 +2,8 @@ import os
 import random
 
 from tqdm import trange, tqdm
-from scipy.misc import imsave
+#from scipy.misc import imsave
+from skimage.io import imsave
 import tensorflow as tf
 import numpy as np
 
@@ -165,7 +166,7 @@ class BicycleGAN(object):
                 lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)
 
             if iter == 0:
-                data = zip(data_A, data_B)
+                data = list(zip(data_A, data_B))
                 random.shuffle(data)
                 data_A, data_B = zip(*data)