# data_loader.py — download, convert (to h5py), and load pix2pix paired datasets.
import os
import subprocess
from glob import glob

import h5py
import imageio
import numpy
import numpy as np
from imageio import imread
from PIL import Image
from skimage.transform import resize
from tqdm import tqdm
# Supported pix2pix dataset names; get_data() rejects anything not listed here.
datasets = ['maps', 'cityscapes', 'facades', 'edges2handbags', 'edges2shoes']
  13. def read_image(path):
  14. image = imread(path)
  15. if len(image.shape) != 3 or image.shape[2] != 3:
  16. print('Wrong image {} with shape {}'.format(path, image.shape))
  17. return None
  18. # split image
  19. h, w, c = image.shape
  20. assert w in [256, 512, 1200], 'Image size mismatch ({}, {})'.format(h, w)
  21. assert h in [128, 256, 600], 'Image size mismatch ({}, {})'.format(h, w)
  22. if 'maps' in path:
  23. image_a = image[:, int(w/2):, :].astype(np.float32) / 255.0
  24. image_b = image[:, :int(w/2), :].astype(np.float32) / 255.0
  25. else:
  26. image_a = image[:, :int(w/2), :].astype(np.float32) / 255.0
  27. image_b = image[:, int(w/2):, :].astype(np.float32) / 255.0
  28. # range of pixel values = [-1.0, 1.0]
  29. image_a = image_a * 2.0 - 1.0
  30. image_b = image_b * 2.0 - 1.0
  31. return image_a, image_b
  32. def read_images(base_dir):
  33. ret = []
  34. for dir_name in ['train', 'val']:
  35. data_dir = os.path.join(base_dir, dir_name)
  36. paths = glob(os.path.join(data_dir, '*.jpg'))
  37. print('# images in {}: {}'.format(data_dir, len(paths)))
  38. images_A = []
  39. images_B = []
  40. for path in tqdm(paths):
  41. image_A, image_B = read_image(path)
  42. if image_A is not None:
  43. images_A.append(image_A)
  44. images_B.append(image_B)
  45. ret.append((dir_name + 'A', images_A))
  46. ret.append((dir_name + 'B', images_B))
  47. return ret
  48. def store_h5py(base_dir, dir_name, images, image_size):
  49. f = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'w')
  50. for i in range(len(images)):
  51. grp = f.create_group(str(i))
  52. if images[i].shape[0] != image_size:
  53. #image = imresize(images[i], (image_size, image_size, 3))
  54. #print(i)
  55. #image = numpy.array(Image.fromarray(images[i]).resize((image_size, image_size, 3)))
  56. image = resize(images[i], (image_size, image_size, 3))
  57. # range of pixel values = [-1.0, 1.0]
  58. image = image.astype(np.float32) / 255.0
  59. image = image * 2.0 - 1.0
  60. grp['image'] = image
  61. else:
  62. grp['image'] = images[i]
  63. f.close()
  64. def convert_h5py(task_name):
  65. print('Generating h5py file')
  66. base_dir = os.path.join('datasets', task_name)
  67. data = read_images(base_dir)
  68. for dir_name, images in data:
  69. if images[0].shape[0] == 256:
  70. store_h5py(base_dir, dir_name, images, 256)
  71. store_h5py(base_dir, dir_name, images, 128)
  72. def read_h5py(task_name, image_size):
  73. base_dir = 'datasets/' + task_name
  74. paths = glob(os.path.join(base_dir, '*_{}.hy'.format(image_size)))
  75. if len(paths) != 4:
  76. convert_h5py(task_name)
  77. ret = []
  78. for dir_name in ['trainA', 'trainB', 'valA', 'valB']:
  79. try:
  80. dataset = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'r')
  81. except:
  82. raise IOError('Dataset is not available. Please try it again')
  83. images = []
  84. for id in dataset:
  85. images.append(dataset[id]['image'].value.astype(np.float32))
  86. ret.append(images)
  87. return ret
  88. def download_dataset(task_name):
  89. print('Download data %s' % task_name)
  90. cmd = './download_pix2pix_dataset.sh ' + task_name
  91. os.system(cmd)
  92. def get_data(task_name, image_size):
  93. assert task_name in datasets, 'Dataset {}_{} is not available'.format(
  94. task_name, image_size)
  95. if not os.path.exists('datasets'):
  96. os.makedirs('datasets')
  97. base_dir = os.path.join('datasets', task_name)
  98. print('Check data %s' % base_dir)
  99. if not os.path.exists(base_dir):
  100. print('Dataset not found. Start downloading...')
  101. download_dataset(task_name)
  102. convert_h5py(task_name)
  103. print('Load data %s' % task_name)
  104. train_A, train_B, test_A, test_B = \
  105. read_h5py(task_name, image_size)
  106. return train_A, train_B, test_A, test_B