utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
"""
Some code adapted from https://github.com/Newmu/dcgan_code
"""
  4. from __future__ import division
  5. import math
  6. import pprint
  7. import scipy.misc
  8. import numpy as np
  9. import copy
  10. from skimage.transform import resize
  11. from skimage.io import imsave
  12. try:
  13. _imread = scipy.misc.imread
  14. except AttributeError:
  15. from imageio import imread as _imread
  16. pp = pprint.PrettyPrinter()
  17. get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
# -----------------------------
# Functions added for CycleGAN
  20. class ImagePool(object):
  21. def __init__(self, maxsize=50):
  22. self.maxsize = maxsize
  23. self.num_img = 0
  24. self.images = []
  25. def __call__(self, image):
  26. if self.maxsize <= 0:
  27. return image
  28. if self.num_img < self.maxsize:
  29. self.images.append(image)
  30. self.num_img += 1
  31. return image
  32. if np.random.rand() > 0.5:
  33. idx = int(np.random.rand()*self.maxsize)
  34. tmp1 = copy.copy(self.images[idx])[0]
  35. self.images[idx][0] = image[0]
  36. idx = int(np.random.rand()*self.maxsize)
  37. tmp2 = copy.copy(self.images[idx])[1]
  38. self.images[idx][1] = image[1]
  39. return [tmp1, tmp2]
  40. else:
  41. return image
  42. def load_test_data(image_path, fine_size=256):
  43. img = imread(image_path)
  44. img = resize(img, [fine_size, fine_size])
  45. img = img/127.5 - 1
  46. return img
  47. def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
  48. img_A = imread(image_path[0])
  49. img_B = imread(image_path[1])
  50. if not is_testing:
  51. img_A = resize(img_A, [load_size, load_size])
  52. img_B = resize(img_B, [load_size, load_size])
  53. h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
  54. w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
  55. img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]
  56. img_B = img_B[h1:h1+fine_size, w1:w1+fine_size]
  57. if np.random.random() > 0.5:
  58. img_A = np.fliplr(img_A)
  59. img_B = np.fliplr(img_B)
  60. else:
  61. img_A = resize(img_A, [fine_size, fine_size])
  62. img_B = resize(img_B, [fine_size, fine_size])
  63. img_A = img_A/127.5 - 1.
  64. img_B = img_B/127.5 - 1.
  65. img_AB = np.concatenate((img_A, img_B), axis=2)
  66. # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim)
  67. return img_AB
  68. # -----------------------------
  69. def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
  70. return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)
  71. def save_images(images, size, image_path):
  72. return imsave(image_path, inverse_transform(images))
  73. def imread(path, is_grayscale = False):
  74. if (is_grayscale):
  75. return _imread(path, flatten=True).astype(np.float)
  76. else:
  77. return _imread(path, pilmode='RGB').astype(np.float)
  78. def merge_images(images, size):
  79. return inverse_transform(images)
  80. def merge(images, size):
  81. h, w = images.shape[1], images.shape[2]
  82. img = np.zeros((h * size[0], w * size[1], 3))
  83. for idx, image in enumerate(images):
  84. i = idx % size[1]
  85. j = idx // size[1]
  86. img[j*h:j*h+h, i*w:i*w+w, :] = image
  87. return img
  88. #def imsave(images, size, path):
  89. # return scipy.misc.imsave(path, merge(images, size))
  90. def center_crop(x, crop_h, crop_w,
  91. resize_h=64, resize_w=64):
  92. if crop_w is None:
  93. crop_w = crop_h
  94. h, w = x.shape[:2]
  95. j = int(round((h - crop_h)/2.))
  96. i = int(round((w - crop_w)/2.))
  97. return scipy.misc.imresize(
  98. x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])
  99. def transform(image, npx=64, is_crop=True, resize_w=64):
  100. # npx : # of pixels width/height of image
  101. if is_crop:
  102. cropped_image = center_crop(image, npx, resize_w=resize_w)
  103. else:
  104. cropped_image = image
  105. return np.array(cropped_image)/127.5 - 1.
  106. def inverse_transform(images):
  107. return (images+1.)/2.