main.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from __future__ import print_function
  2. import os
  3. import argparse
  4. from glob import glob
  5. from PIL import Image
  6. import tensorflow.compat.v1 as tf
  7. tf.disable_v2_behavior()
  8. from model import lowlight_enhance
  9. from utils import *
  10. parser = argparse.ArgumentParser(description='')
  11. parser.add_argument('--use_gpu', dest='use_gpu', type=int, default=1, help='gpu flag, 1 for GPU and 0 for CPU')
  12. parser.add_argument('--gpu_idx', dest='gpu_idx', default="0", help='GPU idx')
  13. parser.add_argument('--gpu_mem', dest='gpu_mem', type=float, default=0.5, help="0 to 1, gpu memory usage")
  14. parser.add_argument('--phase', dest='phase', default='train', help='train or test')
  15. parser.add_argument('--epoch', dest='epoch', type=int, default=100, help='number of total epoches')
  16. parser.add_argument('--batch_size', dest='batch_size', type=int, default=16, help='number of samples in one batch')
  17. parser.add_argument('--patch_size', dest='patch_size', type=int, default=48, help='patch size')
  18. parser.add_argument('--start_lr', dest='start_lr', type=float, default=0.001, help='initial learning rate for adam')
  19. parser.add_argument('--eval_every_epoch', dest='eval_every_epoch', default=20, help='evaluating and saving checkpoints every # epoch')
  20. parser.add_argument('--checkpoint_dir', dest='ckpt_dir', default='./checkpoint', help='directory for checkpoints')
  21. parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='directory for evaluating outputs')
  22. parser.add_argument('--save_dir', dest='save_dir', default='./test_results', help='directory for testing outputs')
  23. parser.add_argument('--test_dir', dest='test_dir', default='./data/test/low', help='directory for testing inputs')
  24. parser.add_argument('--decom', dest='decom', default=0, help='decom flag, 0 for enhanced results only and 1 for decomposition results')
  25. args = parser.parse_args()
  26. def lowlight_train(lowlight_enhance):
  27. if not os.path.exists(args.ckpt_dir):
  28. os.makedirs(args.ckpt_dir)
  29. if not os.path.exists(args.sample_dir):
  30. os.makedirs(args.sample_dir)
  31. lr = args.start_lr * np.ones([args.epoch])
  32. lr[20:] = lr[0] / 10.0
  33. train_low_data = []
  34. train_high_data = []
  35. train_low_data_names = glob('./data/our485/low/*.png') + glob('./data/syn/low/*.png')
  36. train_low_data_names.sort()
  37. train_high_data_names = glob('./data/our485/high/*.png') + glob('./data/syn/high/*.png')
  38. train_high_data_names.sort()
  39. assert len(train_low_data_names) == len(train_high_data_names)
  40. print('[*] Number of training data: %d' % len(train_low_data_names))
  41. for idx in range(len(train_low_data_names)):
  42. low_im = load_images(train_low_data_names[idx])
  43. train_low_data.append(low_im)
  44. high_im = load_images(train_high_data_names[idx])
  45. train_high_data.append(high_im)
  46. eval_low_data = []
  47. eval_high_data = []
  48. eval_low_data_name = glob('./data/eval/low/*.*')
  49. for idx in range(len(eval_low_data_name)):
  50. eval_low_im = load_images(eval_low_data_name[idx])
  51. eval_low_data.append(eval_low_im)
  52. lowlight_enhance.train(train_low_data, train_high_data, eval_low_data, batch_size=args.batch_size, patch_size=args.patch_size, epoch=args.epoch, lr=lr, sample_dir=args.sample_dir, ckpt_dir=os.path.join(args.ckpt_dir, 'Decom'), eval_every_epoch=args.eval_every_epoch, train_phase="Decom")
  53. lowlight_enhance.train(train_low_data, train_high_data, eval_low_data, batch_size=args.batch_size, patch_size=args.patch_size, epoch=args.epoch, lr=lr, sample_dir=args.sample_dir, ckpt_dir=os.path.join(args.ckpt_dir, 'Relight'), eval_every_epoch=args.eval_every_epoch, train_phase="Relight")
  54. def lowlight_test(lowlight_enhance):
  55. if args.test_dir == None:
  56. print("[!] please provide --test_dir")
  57. exit(0)
  58. if not os.path.exists(args.save_dir):
  59. os.makedirs(args.save_dir)
  60. test_low_data_name = glob(os.path.join(args.test_dir) + '/*.*')
  61. test_low_data = []
  62. test_high_data = []
  63. for idx in range(len(test_low_data_name)):
  64. test_low_im = load_images(test_low_data_name[idx])
  65. test_low_data.append(test_low_im)
  66. lowlight_enhance.test(test_low_data, test_high_data, test_low_data_name, save_dir=args.save_dir, decom_flag=args.decom)
  67. def main(_):
  68. if args.use_gpu:
  69. print("[*] GPU\n")
  70. os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_idx
  71. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)
  72. with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
  73. model = lowlight_enhance(sess)
  74. if args.phase == 'train':
  75. lowlight_train(model)
  76. elif args.phase == 'test':
  77. lowlight_test(model)
  78. else:
  79. print('[!] Unknown phase')
  80. exit(0)
  81. else:
  82. print("[*] CPU\n")
  83. with tf.Session() as sess:
  84. model = lowlight_enhance(sess)
  85. if args.phase == 'train':
  86. lowlight_train(model)
  87. elif args.phase == 'test':
  88. lowlight_test(model)
  89. else:
  90. print('[!] Unknown phase')
  91. exit(0)
  92. if __name__ == '__main__':
  93. tf.app.run()