data_loader.py

import os
from glob import glob

import h5py
import numpy as np
# scipy.misc.imread and scipy.misc.imresize were removed from SciPy;
# imageio and Pillow cover the same functionality.
from imageio import imread
from PIL import Image
from tqdm import tqdm

datasets = ['maps', 'cityscapes', 'facades', 'edges2handbags', 'edges2shoes']


def read_image(path):
    image = imread(path)
    if len(image.shape) != 3 or image.shape[2] != 3:
        print('Wrong image {} with shape {}'.format(path, image.shape))
        # Return a pair so the caller's tuple unpacking still works.
        return None, None
    # Each file stores an A|B pair side by side; split it down the middle.
    h, w, c = image.shape
    assert w in (256, 512), 'Image size mismatch ({}, {})'.format(h, w)
    assert h in (128, 256), 'Image size mismatch ({}, {})'.format(h, w)
    image_a = image[:, :w // 2, :].astype(np.float32) / 255.0
    image_b = image[:, w // 2:, :].astype(np.float32) / 255.0
    # Rescale pixel values from [0.0, 1.0] to [-1.0, 1.0].
    image_a = image_a * 2.0 - 1.0
    image_b = image_b * 2.0 - 1.0
    return image_a, image_b


def read_images(base_dir):
    ret = []
    for dir_name in ['train', 'val']:
        data_dir = os.path.join(base_dir, dir_name)
        paths = glob(os.path.join(data_dir, '*.jpg'))
        print('# images in {}: {}'.format(data_dir, len(paths)))
        images_A = []
        images_B = []
        for path in tqdm(paths):
            image_A, image_B = read_image(path)
            if image_A is not None:
                images_A.append(image_A)
                images_B.append(image_B)
        ret.append((dir_name + 'A', images_A))
        ret.append((dir_name + 'B', images_B))
    return ret


def store_h5py(base_dir, dir_name, images, image_size):
    with h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'w') as f:
        for i in range(len(images)):
            grp = f.create_group(str(i))
            if images[i].shape[0] != image_size:
                # Resize with Pillow (scipy.misc.imresize is gone). Map
                # [-1.0, 1.0] back to uint8 first, since PIL expects integer
                # pixel data, then restore the [-1.0, 1.0] range.
                image = ((images[i] + 1.0) / 2.0 * 255.0).astype(np.uint8)
                image = np.asarray(Image.fromarray(image).resize((image_size, image_size)))
                image = image.astype(np.float32) / 255.0
                image = image * 2.0 - 1.0
                grp['image'] = image
            else:
                grp['image'] = images[i]
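
# Layout of each '<split>_<size>.hy' file written above: one group per
# image index, each holding a single (size, size, 3) float32 dataset
# in [-1.0, 1.0]:
#   /0/image, /1/image, /2/image, ...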


def convert_h5py(task_name):
    print('Generating h5py file')
    base_dir = os.path.join('datasets', task_name)
    data = read_images(base_dir)
    for dir_name, images in data:
        if images[0].shape[0] == 256:
            store_h5py(base_dir, dir_name, images, 256)
        store_h5py(base_dir, dir_name, images, 128)


def read_h5py(task_name, image_size):
    base_dir = os.path.join('datasets', task_name)
    paths = glob(os.path.join(base_dir, '*_{}.hy'.format(image_size)))
    # Expect one file per split: trainA, trainB, valA, valB.
    if len(paths) != 4:
        convert_h5py(task_name)
    ret = []
    for dir_name in ['trainA', 'trainB', 'valA', 'valB']:
        try:
            dataset = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'r')
        except OSError:
            raise IOError('Dataset is not available. Please try again')
        images = []
        # h5py 3.0 removed Dataset.value; index with [()] instead.
        for key in dataset:
            images.append(dataset[key]['image'][()].astype(np.float32))
        ret.append(images)
    return ret


def download_dataset(task_name):
    print('Download data %s' % task_name)
    cmd = './download_pix2pix_dataset.sh ' + task_name
    os.system(cmd)


def get_data(task_name, image_size):
    assert task_name in datasets, 'Dataset {} is not available'.format(task_name)
    if not os.path.exists('datasets'):
        os.makedirs('datasets')
    base_dir = os.path.join('datasets', task_name)
    print('Check data %s' % base_dir)
    if not os.path.exists(base_dir):
        print('Dataset not found. Start downloading...')
        download_dataset(task_name)
        convert_h5py(task_name)
    print('Load data %s' % task_name)
    # The 'val' split doubles as the test set.
    train_A, train_B, test_A, test_B = read_h5py(task_name, image_size)
    return train_A, train_B, test_A, test_B
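

# Minimal usage sketch (hypothetical driver, not part of the original file):
# fetches the facades dataset if needed and loads the 128x128 splits.
# Assumes ./download_pix2pix_dataset.sh is present and executable.
if __name__ == '__main__':
    train_A, train_B, test_A, test_B = get_data('facades', 128)
    print('train: {} pairs, test: {} pairs'.format(len(train_A), len(test_A)))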