bicycle-gan.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import argparse
  2. import sys
  3. import signal
  4. import os
  5. from datetime import datetime
  6. import tensorflow as tf
  7. from data_loader import get_data
  8. from model import BicycleGAN
  9. from utils import logger, makedirs
  10. # parsing cmd arguments
  11. parser = argparse.ArgumentParser(description="Run commands")
  12. def str2bool(v):
  13. return v.lower() == 'true'
  14. parser.add_argument('--train', default=True, type=str2bool,
  15. help="Training mode")
  16. parser.add_argument('--task', type=str, default='edges2shoes',
  17. help='Task name')
  18. parser.add_argument('--gamma', type=float, default=1,
  19. help='Loss coefficient')
  20. parser.add_argument('--lambda1', type=float, default=1,
  21. help='Loss coefficient')
  22. parser.add_argument('--lambda2', type=float, default=1,
  23. help='Loss coefficient')
  24. parser.add_argument('--instance_normalization', default=False, type=bool,
  25. help="Use instance norm instead of batch norm")
  26. parser.add_argument('--log_step', default=100, type=int,
  27. help="Tensorboard log frequency")
  28. parser.add_argument('--batch_size', default=1, type=int,
  29. help="Batch size")
  30. parser.add_argument('--image_size', default=128, type=int,
  31. help="Image size")
  32. parser.add_argument('--latent_dim', default=8, type=int,
  33. help="Dimensionality of latent vector")
  34. parser.add_argument('--load_model', default='',
  35. help='Model path to load (e.g., train_2017-07-07_01-23-45)')
  36. parser.add_argument('--gpu', default="1", type=str,
  37. help="gpu index for CUDA_VISIBLE_DEVICES")
  38. class FastSaver(tf.train.Saver):
  39. def save(self, sess, save_path, global_step=None, latest_filename=None,
  40. meta_graph_suffix="meta", write_meta_graph=True):
  41. super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
  42. meta_graph_suffix, False)
  43. def run(args):
  44. # setting the GPU #
  45. os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
  46. os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
  47. logger.info('Read data:')
  48. train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)
  49. logger.info('Build graph:')
  50. model = BicycleGAN(args)
  51. variables_to_save = tf.global_variables()
  52. init_op = tf.variables_initializer(variables_to_save)
  53. init_all_op = tf.global_variables_initializer()
  54. saver = FastSaver(variables_to_save)
  55. logger.info('Trainable vars:')
  56. var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
  57. tf.get_variable_scope().name)
  58. for v in var_list:
  59. logger.info(' %s %s', v.name, v.get_shape())
  60. if args.load_model != '':
  61. model_name = args.load_model
  62. else:
  63. model_name = '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
  64. logdir = './logs'
  65. makedirs(logdir)
  66. logdir = os.path.join(logdir, model_name)
  67. logger.info('Events directory: %s', logdir)
  68. summary_writer = tf.summary.FileWriter(logdir)
  69. def init_fn(sess):
  70. logger.info('Initializing all parameters.')
  71. sess.run(init_all_op)
  72. sv = tf.train.Supervisor(is_chief=True,
  73. logdir=logdir,
  74. saver=saver,
  75. summary_op=None,
  76. init_op=init_op,
  77. init_fn=init_fn,
  78. summary_writer=summary_writer,
  79. ready_op=tf.report_uninitialized_variables(variables_to_save),
  80. global_step=model.global_step,
  81. save_model_secs=300,
  82. save_summaries_secs=30)
  83. if args.train:
  84. logger.info("Starting training session.")
  85. with sv.managed_session() as sess:
  86. model.train(sess, summary_writer, train_A, train_B)
  87. logger.info("Starting testing session.")
  88. with sv.managed_session() as sess:
  89. base_dir = os.path.join('results', model_name)
  90. makedirs(base_dir)
  91. model.test(sess, test_A, test_B, base_dir)
  92. def main():
  93. args, unparsed = parser.parse_known_args()
  94. def shutdown(signal, frame):
  95. tf.logging.warn('Received signal %s: exiting', signal)
  96. sys.exit(128+signal)
  97. signal.signal(signal.SIGHUP, shutdown)
  98. signal.signal(signal.SIGINT, shutdown)
  99. signal.signal(signal.SIGTERM, shutdown)
  100. run(args)
  101. if __name__ == "__main__":
  102. main()