data_loader.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import os
  2. from glob import glob
  3. from scipy.misc import imread, imresize
  4. import numpy as np
  5. from tqdm import tqdm
  6. import h5py
  7. datasets = ['maps', 'cityscapes', 'facades', 'edges2handbags', 'edges2shoes']
  8. def read_image(path):
  9. image = imread(path)
  10. if len(image.shape) != 3 or image.shape[2] != 3:
  11. print('Wrong image {} with shape {}'.format(path, image.shape))
  12. return None
  13. # split image
  14. h, w, c = image.shape
  15. assert w in [256, 512, 1200], 'Image size mismatch ({}, {})'.format(h, w)
  16. assert h in [128, 256, 600], 'Image size mismatch ({}, {})'.format(h, w)
  17. if 'maps' in path:
  18. image_a = image[:, w/2:, :].astype(np.float32) / 255.0
  19. image_b = image[:, :w/2, :].astype(np.float32) / 255.0
  20. else:
  21. image_a = image[:, :w/2, :].astype(np.float32) / 255.0
  22. image_b = image[:, w/2:, :].astype(np.float32) / 255.0
  23. # range of pixel values = [-1.0, 1.0]
  24. image_a = image_a * 2.0 - 1.0
  25. image_b = image_b * 2.0 - 1.0
  26. return image_a, image_b
  27. def read_images(base_dir):
  28. ret = []
  29. for dir_name in ['train', 'val']:
  30. data_dir = os.path.join(base_dir, dir_name)
  31. paths = glob(os.path.join(data_dir, '*.jpg'))
  32. print('# images in {}: {}'.format(data_dir, len(paths)))
  33. images_A = []
  34. images_B = []
  35. for path in tqdm(paths):
  36. image_A, image_B = read_image(path)
  37. if image_A is not None:
  38. images_A.append(image_A)
  39. images_B.append(image_B)
  40. ret.append((dir_name + 'A', images_A))
  41. ret.append((dir_name + 'B', images_B))
  42. return ret
  43. def store_h5py(base_dir, dir_name, images, image_size):
  44. f = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'w')
  45. for i in range(len(images)):
  46. grp = f.create_group(str(i))
  47. if images[i].shape[0] != image_size:
  48. image = imresize(images[i], (image_size, image_size, 3))
  49. # range of pixel values = [-1.0, 1.0]
  50. image = image.astype(np.float32) / 255.0
  51. image = image * 2.0 - 1.0
  52. grp['image'] = image
  53. else:
  54. grp['image'] = images[i]
  55. f.close()
  56. def convert_h5py(task_name):
  57. print('Generating h5py file')
  58. base_dir = os.path.join('datasets', task_name)
  59. data = read_images(base_dir)
  60. for dir_name, images in data:
  61. if images[0].shape[0] == 256:
  62. store_h5py(base_dir, dir_name, images, 256)
  63. store_h5py(base_dir, dir_name, images, 128)
  64. def read_h5py(task_name, image_size):
  65. base_dir = 'datasets/' + task_name
  66. paths = glob(os.path.join(base_dir, '*_{}.hy'.format(image_size)))
  67. if len(paths) != 4:
  68. convert_h5py(task_name)
  69. ret = []
  70. for dir_name in ['trainA', 'trainB', 'valA', 'valB']:
  71. try:
  72. dataset = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'r')
  73. except:
  74. raise IOError('Dataset is not available. Please try it again')
  75. images = []
  76. for id in dataset:
  77. images.append(dataset[id]['image'].value.astype(np.float32))
  78. ret.append(images)
  79. return ret
  80. def download_dataset(task_name):
  81. print('Download data %s' % task_name)
  82. cmd = './download_pix2pix_dataset.sh ' + task_name
  83. os.system(cmd)
  84. def get_data(task_name, image_size):
  85. assert task_name in datasets, 'Dataset {}_{} is not available'.format(
  86. task_name, image_size)
  87. if not os.path.exists('datasets'):
  88. os.makedirs('datasets')
  89. base_dir = os.path.join('datasets', task_name)
  90. print('Check data %s' % base_dir)
  91. if not os.path.exists(base_dir):
  92. print('Dataset not found. Start downloading...')
  93. download_dataset(task_name)
  94. convert_h5py(task_name)
  95. print('Load data %s' % task_name)
  96. train_A, train_B, test_A, test_B = \
  97. read_h5py(task_name, image_size)
  98. return train_A, train_B, test_A, test_B