utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
"""
Some code from https://github.com/Newmu/dcgan_code
"""
  4. from __future__ import division
  5. import math
  6. import pprint
  7. import scipy.misc
  8. import numpy as np
  9. import copy
# Bind ``_imread`` to scipy.misc.imread when that attribute exists; if looking
# it up raises AttributeError, use imageio's ``imread`` instead.
try:
    _imread = scipy.misc.imread
except AttributeError:
    from imageio import imread as _imread
  14. pp = pprint.PrettyPrinter()
  15. get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
  16. # -----------------------------
# newly added functions for CycleGAN
  18. class ImagePool(object):
  19. def __init__(self, maxsize=50):
  20. self.maxsize = maxsize
  21. self.num_img = 0
  22. self.images = []
  23. def __call__(self, image):
  24. if self.maxsize <= 0:
  25. return image
  26. if self.num_img < self.maxsize:
  27. self.images.append(image)
  28. self.num_img += 1
  29. return image
  30. if np.random.rand() > 0.5:
  31. idx = int(np.random.rand()*self.maxsize)
  32. tmp1 = copy.copy(self.images[idx])[0]
  33. self.images[idx][0] = image[0]
  34. idx = int(np.random.rand()*self.maxsize)
  35. tmp2 = copy.copy(self.images[idx])[1]
  36. self.images[idx][1] = image[1]
  37. return [tmp1, tmp2]
  38. else:
  39. return image
  40. def load_test_data(image_path, fine_size=256):
  41. img = imread(image_path)
  42. img = scipy.misc.imresize(img, [fine_size, fine_size])
  43. img = img/127.5 - 1
  44. return img
  45. def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
  46. img_A = imread(image_path[0])
  47. img_B = imread(image_path[1])
  48. if not is_testing:
  49. img_A = scipy.misc.imresize(img_A, [load_size, load_size])
  50. img_B = scipy.misc.imresize(img_B, [load_size, load_size])
  51. h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
  52. w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
  53. img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]
  54. img_B = img_B[h1:h1+fine_size, w1:w1+fine_size]
  55. if np.random.random() > 0.5:
  56. img_A = np.fliplr(img_A)
  57. img_B = np.fliplr(img_B)
  58. else:
  59. img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
  60. img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
  61. img_A = img_A/127.5 - 1.
  62. img_B = img_B/127.5 - 1.
  63. img_AB = np.concatenate((img_A, img_B), axis=2)
  64. # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim)
  65. return img_AB
  66. # -----------------------------
  67. def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
  68. return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)
  69. def save_images(images, size, image_path):
  70. return imsave(inverse_transform(images), size, image_path)
  71. def imread(path, is_grayscale = False):
  72. if (is_grayscale):
  73. return _imread(path, flatten=True).astype(np.float)
  74. else:
  75. return _imread(path, mode='RGB').astype(np.float)
  76. def merge_images(images, size):
  77. return inverse_transform(images)
  78. def merge(images, size):
  79. h, w = images.shape[1], images.shape[2]
  80. img = np.zeros((h * size[0], w * size[1], 3))
  81. for idx, image in enumerate(images):
  82. i = idx % size[1]
  83. j = idx // size[1]
  84. img[j*h:j*h+h, i*w:i*w+w, :] = image
  85. return img
  86. def imsave(images, size, path):
  87. return scipy.misc.imsave(path, merge(images, size))
  88. def center_crop(x, crop_h, crop_w,
  89. resize_h=64, resize_w=64):
  90. if crop_w is None:
  91. crop_w = crop_h
  92. h, w = x.shape[:2]
  93. j = int(round((h - crop_h)/2.))
  94. i = int(round((w - crop_w)/2.))
  95. return scipy.misc.imresize(
  96. x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])
  97. def transform(image, npx=64, is_crop=True, resize_w=64):
  98. # npx : # of pixels width/height of image
  99. if is_crop:
  100. cropped_image = center_crop(image, npx, resize_w=resize_w)
  101. else:
  102. cropped_image = image
  103. return np.array(cropped_image)/127.5 - 1.
  104. def inverse_transform(images):
  105. return (images+1.)/2.