bicycle-gan.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import argparse
  2. import sys
  3. import signal
  4. import os
  5. from datetime import datetime
  6. import tensorflow as tf
  7. from data_loader import get_data
  8. from model import BicycleGAN
  9. from utils import logger, makedirs
  # Command-line interface: main() reads its options from this parser, and
  # the individual options are registered on it with add_argument().
  parser = argparse.ArgumentParser(
      description="Run commands",
  )
  def str2bool(v):
      """Convert a command-line string into a boolean.

      Intended as an ``argparse`` ``type=`` converter. The match is
      case-insensitive and ignores surrounding whitespace.
      ``true/t/yes/y/1`` map to True and ``false/f/no/n/0`` map to False.
      Any other value raises ``argparse.ArgumentTypeError``, which argparse
      turns into a usage error. The previous ``v.lower() == 'true'`` check
      silently treated typos such as ``ture`` as False.

      Args:
          v: Raw option value. A value that is already a ``bool`` is
              returned unchanged.

      Returns:
          The parsed boolean.

      Raises:
          argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
              spelling.
      """
      if isinstance(v, bool):
          return v
      normalized = v.strip().lower()
      if normalized in ('true', 't', 'yes', 'y', '1'):
          return True
      if normalized in ('false', 'f', 'no', 'n', '0'):
          return False
      raise argparse.ArgumentTypeError(
          'Boolean value expected, got {!r}'.format(v))
  # Register every command-line option on the module-level parser.
  # Boolean options use str2bool: argparse's type=bool would call bool() on
  # the raw string, and bool("False") is True, so such a flag could never be
  # switched off from the command line.
  parser.add_argument('--train', default=True, type=str2bool,
                      help="Training mode")
  parser.add_argument('--task', type=str, default='edges2shoes',
                      help='Task name')
  parser.add_argument('--coeff_kl', type=float, default=0.01,
                      help='Loss coefficient for KL divergence')
  # Float default so the attribute has the same type whether or not the
  # option is given (type=float is only applied to command-line strings).
  parser.add_argument('--coeff_reconstruct', type=float, default=10.0,
                      help='Loss coefficient for reconstruct')
  parser.add_argument('--coeff_latent', type=float, default=0.5,
                      help='Loss coefficient for latent cycle')
  parser.add_argument('--instance_normalization', default=False, type=str2bool,
                      help="Use instance norm instead of batch norm")
  parser.add_argument('--log_step', default=100, type=int,
                      help="Tensorboard log frequency")
  parser.add_argument('--batch_size', default=1, type=int,
                      help="Batch size")
  parser.add_argument('--image_size', default=256, type=int,
                      help="Image size")
  parser.add_argument('--latent_dim', default=8, type=int,
                      help="Dimensionality of latent vector")
  parser.add_argument('--use_resnet', default=True, type=str2bool,
                      help="Use the ResNet model for the encoder")
  # run() uses this value as the run-directory name under ./logs; new runs
  # are named '<task>_<YYYY-mm-dd_HH-MM-SS>'.
  parser.add_argument('--load_model', default='',
                      help='Run name under ./logs to load '
                           '(e.g., edges2shoes_2017-07-07_01-23-45)')
  parser.add_argument('--gpu', default="1", type=str,
                      help="gpu index for CUDA_VISIBLE_DEVICES")
  class FastSaver(tf.train.Saver):
      """``tf.train.Saver`` whose ``save`` never writes the MetaGraph file.

      Every call to ``save`` forces ``write_meta_graph=False``, so only the
      variable checkpoint is written. All other arguments are forwarded to
      ``tf.train.Saver.save`` unchanged.
      """

      def save(self, sess, save_path, global_step=None, latest_filename=None,
               meta_graph_suffix="meta", write_meta_graph=True, **kwargs):
          """Save a checkpoint without the ``.meta`` graph file.

          Args:
              sess: Session whose variable values are saved.
              save_path: Checkpoint path prefix.
              global_step: Optional step number appended to ``save_path``.
              latest_filename: Optional name of the checkpoint-state file.
              meta_graph_suffix: Forwarded unchanged. It is unused because
                  the MetaGraph is never written.
              write_meta_graph: Accepted for signature compatibility but
                  ignored; always treated as False.
              **kwargs: Any further keyword arguments supported by the
                  installed TensorFlow's ``Saver.save``.

          Returns:
              The value returned by ``tf.train.Saver.save`` (the checkpoint
              path prefix). The previous override discarded it and
              returned None.
          """
          return super(FastSaver, self).save(
              sess, save_path,
              global_step=global_step,
              latest_filename=latest_filename,
              meta_graph_suffix=meta_graph_suffix,
              write_meta_graph=False,
              **kwargs)
  def run(args):
      """Build the BicycleGAN graph, then optionally train it and test it.

      Args:
          args: Parsed command-line namespace. This function reads ``gpu``,
              ``task``, ``image_size``, ``load_model`` and ``train``, and
              passes the whole namespace to ``BicycleGAN``.

      Side effects:
          - Sets ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES``.
          - Creates ``./logs/<model_name>``, which the ``tf.train.Supervisor``
            uses for events and checkpoints.
          - Creates ``./results/<model_name>``, the output directory handed
            to ``model.test``.
          - If ``args.load_model`` names an existing run, the Supervisor
            restores the latest checkpoint found in that directory.
      """
      # Select the GPU before any TensorFlow session is created.
      os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
      os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

      logger.info('Read data:')
      train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)

      logger.info('Build graph:')
      model = BicycleGAN(args)
      variables_to_save = tf.global_variables()
      init_op = tf.variables_initializer(variables_to_save)
      init_all_op = tf.global_variables_initializer()
      saver = FastSaver(variables_to_save)

      logger.info('Trainable vars:')
      var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   tf.get_variable_scope().name)
      for v in var_list:
          logger.info(' %s %s', v.name, v.get_shape())

      # Reuse an existing run directory, or start a new timestamped run.
      if args.load_model != '':
          model_name = args.load_model
      else:
          model_name = '{}_{}'.format(
              args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
      logdir = './logs'
      makedirs(logdir)
      logdir = os.path.join(logdir, model_name)
      logger.info('Events directory: %s', logdir)
      summary_writer = tf.summary.FileWriter(logdir)
      makedirs('./results')

      def init_fn(sess):
          # Called by the Supervisor only when no checkpoint is restored.
          logger.info('Initializing all parameters.')
          sess.run(init_all_op)

      sv = tf.train.Supervisor(is_chief=True,
                               logdir=logdir,
                               saver=saver,
                               summary_op=None,
                               init_op=init_op,
                               init_fn=init_fn,
                               summary_writer=summary_writer,
                               ready_op=tf.report_uninitialized_variables(variables_to_save),
                               global_step=model.global_step,
                               save_model_secs=300,
                               save_summaries_secs=30)

      # Use one managed session for both phases. Previously a second session
      # was opened for testing, and it restored the last *periodic* checkpoint
      # (written every save_model_secs). That could be minutes older than the
      # trained weights, and closing the first session also closed the
      # shared summary writer.
      with sv.managed_session() as sess:
          if args.train:
              logger.info("Starting training session.")
              model.train(sess, summary_writer, train_A, train_B)
              # Persist the final weights; the Supervisor's timer would not.
              sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
          logger.info("Starting testing session.")
          base_dir = os.path.join('results', model_name)
          makedirs(base_dir)
          model.test(sess, test_A, test_B, base_dir)
  def main():
      """Parse command-line options, install signal handlers and call run().

      Options the parser does not recognise are not fatal (this uses
      ``parse_known_args``), but they are now reported with a warning so
      that a misspelled flag does not go unnoticed.
      """
      args, unparsed = parser.parse_known_args()
      if unparsed:
          tf.logging.warn('Ignoring unrecognized arguments: %s', unparsed)

      def shutdown(signum, frame):
          # Exit with the conventional "128 + signal number" status.
          tf.logging.warn('Received signal %s: exiting', signum)
          sys.exit(128 + signum)

      # SIGHUP is not defined on every platform (e.g. Windows), so look each
      # signal up instead of referencing it directly.
      for sig_name in ('SIGHUP', 'SIGINT', 'SIGTERM'):
          sig = getattr(signal, sig_name, None)
          if sig is not None:
              signal.signal(sig, shutdown)

      run(args)


  if __name__ == "__main__":
      main()