# bicycle-gan.py: command-line entry point that builds, trains and tests a BicycleGAN model.

  1. import argparse
  2. import sys
  3. import signal
  4. import os
  5. from datetime import datetime
  6. import tensorflow as tf
  7. from data_loader import get_data
  8. from model import BicycleGAN
  9. from utils import logger, makedirs
  10. # parsing cmd arguments
  11. parser = argparse.ArgumentParser(description="Run commands")
  12. def str2bool(v):
  13. return v.lower() == 'true'
  14. parser.add_argument('--train', default=True, type=str2bool,
  15. help="Training mode")
  16. parser.add_argument('--task', type=str, default='edges2shoes',
  17. help='Task name')
  18. parser.add_argument('--coeff_vae', type=float, default=1.0,
  19. help='Loss coefficient for VAE')
  20. parser.add_argument('--coeff_kl', type=float, default=0.01,
  21. help='Loss coefficient for KL divergence')
  22. parser.add_argument('--coeff_reconstruct', type=float, default=10,
  23. help='Loss coefficient for reconstruct')
  24. parser.add_argument('--coeff_latent', type=float, default=0.5,
  25. help='Loss coefficient for latent cycle')
  26. parser.add_argument('--instance_normalization', default=False, type=bool,
  27. help="Use instance norm instead of batch norm")
  28. parser.add_argument('--log_step', default=100, type=int,
  29. help="Tensorboard log frequency")
  30. parser.add_argument('--batch_size', default=1, type=int,
  31. help="Batch size")
  32. parser.add_argument('--image_size', default=256, type=int,
  33. help="Image size")
  34. parser.add_argument('--latent_dim', default=8, type=int,
  35. help="Dimensionality of latent vector")
  36. parser.add_argument('--use_resnet', default=True, type=bool,
  37. help="Use the ResNet model for the encoder")
  38. parser.add_argument('--load_model', default='',
  39. help='Model path to load (e.g., train_2017-07-07_01-23-45)')
  40. parser.add_argument('--gpu', default="1", type=str,
  41. help="gpu index for CUDA_VISIBLE_DEVICES")
  42. class FastSaver(tf.train.Saver):
  43. def save(self, sess, save_path, global_step=None, latest_filename=None,
  44. meta_graph_suffix="meta", write_meta_graph=True):
  45. super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
  46. meta_graph_suffix, False)
  47. def run(args):
  48. # setting the GPU #
  49. os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
  50. os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
  51. logger.info('Read data:')
  52. train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)
  53. logger.info('Build graph:')
  54. model = BicycleGAN(args)
  55. variables_to_save = tf.global_variables()
  56. init_op = tf.variables_initializer(variables_to_save)
  57. init_all_op = tf.global_variables_initializer()
  58. saver = FastSaver(variables_to_save)
  59. logger.info('Trainable vars:')
  60. var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
  61. tf.get_variable_scope().name)
  62. for v in var_list:
  63. logger.info(' %s %s', v.name, v.get_shape())
  64. if args.load_model != '':
  65. model_name = args.load_model
  66. else:
  67. model_name = '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
  68. logdir = './logs'
  69. makedirs(logdir)
  70. logdir = os.path.join(logdir, model_name)
  71. logger.info('Events directory: %s', logdir)
  72. summary_writer = tf.summary.FileWriter(logdir)
  73. makedirs('./results')
  74. def init_fn(sess):
  75. logger.info('Initializing all parameters.')
  76. sess.run(init_all_op)
  77. sv = tf.train.Supervisor(is_chief=True,
  78. logdir=logdir,
  79. saver=saver,
  80. summary_op=None,
  81. init_op=init_op,
  82. init_fn=init_fn,
  83. summary_writer=summary_writer,
  84. ready_op=tf.report_uninitialized_variables(variables_to_save),
  85. global_step=model.global_step,
  86. save_model_secs=300,
  87. save_summaries_secs=30)
  88. if args.train:
  89. logger.info("Starting training session.")
  90. with sv.managed_session() as sess:
  91. model.train(sess, summary_writer, train_A, train_B)
  92. logger.info("Starting testing session.")
  93. with sv.managed_session() as sess:
  94. base_dir = os.path.join('results', model_name)
  95. makedirs(base_dir)
  96. model.test(sess, test_A, test_B, base_dir)
  97. def main():
  98. args, unparsed = parser.parse_known_args()
  99. def shutdown(signal, frame):
  100. tf.logging.warn('Received signal %s: exiting', signal)
  101. sys.exit(128+signal)
  102. signal.signal(signal.SIGHUP, shutdown)
  103. signal.signal(signal.SIGINT, shutdown)
  104. signal.signal(signal.SIGTERM, shutdown)
  105. run(args)
  106. if __name__ == "__main__":
  107. main()