bicycle-gan.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import argparse
  2. import sys
  3. import signal
  4. import os
  5. from datetime import datetime
  6. import tensorflow as tf
  7. from data_loader import get_data
  8. from model import BicycleGAN
  9. from utils import logger, makedirs
  10. # parsing cmd arguments
  11. parser = argparse.ArgumentParser(description="Run commands")
  12. def str2bool(v):
  13. return v.lower() == 'true'
  14. parser.add_argument('--train', default=True, type=str2bool,
  15. help="Training mode")
  16. parser.add_argument('--task', type=str, default='edges2shoes',
  17. help='Task name')
  18. parser.add_argument('--coeff_gan', type=float, default=1.0,
  19. help='Loss coefficient for GAN loss')
  20. parser.add_argument('--coeff_vae', type=float, default=1.0,
  21. help='Loss coefficient for VAE loss')
  22. parser.add_argument('--coeff_kl', type=float, default=0.01,
  23. help='Loss coefficient for KL divergence')
  24. parser.add_argument('--coeff_reconstruct', type=float, default=10,
  25. help='Loss coefficient for reconstruction error')
  26. parser.add_argument('--coeff_latent', type=float, default=0.5,
  27. help='Loss coefficient for latent cycle loss')
  28. parser.add_argument('--instance_normalization', default=False, type=bool,
  29. help="Use instance norm instead of batch norm")
  30. parser.add_argument('--log_step', default=100, type=int,
  31. help="Tensorboard log frequency")
  32. parser.add_argument('--batch_size', default=1, type=int,
  33. help="Batch size")
  34. parser.add_argument('--image_size', default=256, type=int,
  35. help="Image size")
  36. parser.add_argument('--latent_dim', default=8, type=int,
  37. help="Dimensionality of latent vector")
  38. parser.add_argument('--use_resnet', default=True, type=bool,
  39. help="Use the ResNet model for the encoder")
  40. parser.add_argument('--load_model', default='',
  41. help='Model path to load (e.g., train_2017-07-07_01-23-45)')
  42. parser.add_argument('--gpu', default="1", type=str,
  43. help="gpu index for CUDA_VISIBLE_DEVICES")
  44. class FastSaver(tf.train.Saver):
  45. def save(self, sess, save_path, global_step=None, latest_filename=None,
  46. meta_graph_suffix="meta", write_meta_graph=True):
  47. super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
  48. meta_graph_suffix, False)
  49. def run(args):
  50. # setting the GPU #
  51. os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
  52. os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
  53. logger.info('Read data:')
  54. train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)
  55. logger.info('Build graph:')
  56. model = BicycleGAN(args)
  57. variables_to_save = tf.global_variables()
  58. init_op = tf.variables_initializer(variables_to_save)
  59. init_all_op = tf.global_variables_initializer()
  60. saver = FastSaver(variables_to_save)
  61. logger.info('Trainable vars:')
  62. var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
  63. tf.get_variable_scope().name)
  64. for v in var_list:
  65. logger.info(' %s %s', v.name, v.get_shape())
  66. if args.load_model != '':
  67. model_name = args.load_model
  68. else:
  69. model_name = '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
  70. logdir = './logs'
  71. makedirs(logdir)
  72. logdir = os.path.join(logdir, model_name)
  73. logger.info('Events directory: %s', logdir)
  74. summary_writer = tf.summary.FileWriter(logdir)
  75. makedirs('./results')
  76. def init_fn(sess):
  77. logger.info('Initializing all parameters.')
  78. sess.run(init_all_op)
  79. sv = tf.train.Supervisor(is_chief=True,
  80. logdir=logdir,
  81. saver=saver,
  82. summary_op=None,
  83. init_op=init_op,
  84. init_fn=init_fn,
  85. summary_writer=summary_writer,
  86. ready_op=tf.report_uninitialized_variables(variables_to_save),
  87. global_step=model.global_step,
  88. save_model_secs=300,
  89. save_summaries_secs=30)
  90. if args.train:
  91. logger.info("Starting training session.")
  92. with sv.managed_session() as sess:
  93. model.train(sess, summary_writer, train_A, train_B)
  94. logger.info("Starting testing session.")
  95. with sv.managed_session() as sess:
  96. base_dir = os.path.join('results', model_name)
  97. makedirs(base_dir)
  98. model.test(sess, test_A, test_B, base_dir)
  99. def main():
  100. args, unparsed = parser.parse_known_args()
  101. def shutdown(signal, frame):
  102. tf.logging.warn('Received signal %s: exiting', signal)
  103. sys.exit(128+signal)
  104. signal.signal(signal.SIGHUP, shutdown)
  105. signal.signal(signal.SIGINT, shutdown)
  106. signal.signal(signal.SIGTERM, shutdown)
  107. run(args)
  108. if __name__ == "__main__":
  109. main()