utils.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import numpy as np
  2. from PIL import Image
  3. def data_augmentation(image, mode):
  4. if mode == 0:
  5. # original
  6. return image
  7. elif mode == 1:
  8. # flip up and down
  9. return np.flipud(image)
  10. elif mode == 2:
  11. # rotate counterwise 90 degree
  12. return np.rot90(image)
  13. elif mode == 3:
  14. # rotate 90 degree and flip up and down
  15. image = np.rot90(image)
  16. return np.flipud(image)
  17. elif mode == 4:
  18. # rotate 180 degree
  19. return np.rot90(image, k=2)
  20. elif mode == 5:
  21. # rotate 180 degree and flip
  22. image = np.rot90(image, k=2)
  23. return np.flipud(image)
  24. elif mode == 6:
  25. # rotate 270 degree
  26. return np.rot90(image, k=3)
  27. elif mode == 7:
  28. # rotate 270 degree and flip
  29. image = np.rot90(image, k=3)
  30. return np.flipud(image)
  31. def load_images(file):
  32. im = Image.open(file)
  33. return np.array(im, dtype="float32") / 255.0
  34. def save_images(filepath, result_1, result_2 = None):
  35. result_1 = np.squeeze(result_1)
  36. result_2 = np.squeeze(result_2)
  37. if not result_2.any():
  38. cat_image = result_1
  39. else:
  40. cat_image = np.concatenate([result_1, result_2], axis = 1)
  41. im = Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8'))
  42. im.save(filepath, 'png')