module.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. from __future__ import division
  2. import tensorflow as tf
  3. from ops import *
  4. from utils import *
  5. def discriminator(image, options, reuse=False, name="discriminator"):
  6. with tf.variable_scope(name):
  7. # image is 256 x 256 x input_c_dim
  8. if reuse:
  9. tf.get_variable_scope().reuse_variables()
  10. else:
  11. assert tf.get_variable_scope().reuse is False
  12. h0 = lrelu(conv2d(image, options.df_dim, name='d_h0_conv'))
  13. # h0 is (128 x 128 x self.df_dim)
  14. h1 = lrelu(instance_norm(conv2d(h0, options.df_dim*2, name='d_h1_conv'), 'd_bn1'))
  15. # h1 is (64 x 64 x self.df_dim*2)
  16. h2 = lrelu(instance_norm(conv2d(h1, options.df_dim*4, name='d_h2_conv'), 'd_bn2'))
  17. # h2 is (32x 32 x self.df_dim*4)
  18. h3 = lrelu(instance_norm(conv2d(h2, options.df_dim*8, s=1, name='d_h3_conv'), 'd_bn3'))
  19. # h3 is (32 x 32 x self.df_dim*8)
  20. h4 = conv2d(h3, 1, s=1, name='d_h3_pred')
  21. # h4 is (32 x 32 x 1)
  22. return h4
  23. def generator_unet(image, options, reuse=False, name="generator"):
  24. dropout_rate = 0.5 if options.is_training else 1.0
  25. with tf.variable_scope(name):
  26. # image is 256 x 256 x input_c_dim
  27. if reuse:
  28. tf.get_variable_scope().reuse_variables()
  29. else:
  30. assert tf.get_variable_scope().reuse is False
  31. # image is (256 x 256 x input_c_dim)
  32. e1 = instance_norm(conv2d(image, options.gf_dim, name='g_e1_conv'))
  33. # e1 is (128 x 128 x self.gf_dim)
  34. e2 = instance_norm(conv2d(lrelu(e1), options.gf_dim*2, name='g_e2_conv'), 'g_bn_e2')
  35. # e2 is (64 x 64 x self.gf_dim*2)
  36. e3 = instance_norm(conv2d(lrelu(e2), options.gf_dim*4, name='g_e3_conv'), 'g_bn_e3')
  37. # e3 is (32 x 32 x self.gf_dim*4)
  38. e4 = instance_norm(conv2d(lrelu(e3), options.gf_dim*8, name='g_e4_conv'), 'g_bn_e4')
  39. # e4 is (16 x 16 x self.gf_dim*8)
  40. e5 = instance_norm(conv2d(lrelu(e4), options.gf_dim*8, name='g_e5_conv'), 'g_bn_e5')
  41. # e5 is (8 x 8 x self.gf_dim*8)
  42. e6 = instance_norm(conv2d(lrelu(e5), options.gf_dim*8, name='g_e6_conv'), 'g_bn_e6')
  43. # e6 is (4 x 4 x self.gf_dim*8)
  44. e7 = instance_norm(conv2d(lrelu(e6), options.gf_dim*8, name='g_e7_conv'), 'g_bn_e7')
  45. # e7 is (2 x 2 x self.gf_dim*8)
  46. e8 = instance_norm(conv2d(lrelu(e7), options.gf_dim*8, name='g_e8_conv'), 'g_bn_e8')
  47. # e8 is (1 x 1 x self.gf_dim*8)
  48. d1 = deconv2d(tf.nn.relu(e8), options.gf_dim*8, name='g_d1')
  49. d1 = tf.nn.dropout(d1, dropout_rate)
  50. d1 = tf.concat([instance_norm(d1, 'g_bn_d1'), e7], 3)
  51. # d1 is (2 x 2 x self.gf_dim*8*2)
  52. d2 = deconv2d(tf.nn.relu(d1), options.gf_dim*8, name='g_d2')
  53. d2 = tf.nn.dropout(d2, dropout_rate)
  54. d2 = tf.concat([instance_norm(d2, 'g_bn_d2'), e6], 3)
  55. # d2 is (4 x 4 x self.gf_dim*8*2)
  56. d3 = deconv2d(tf.nn.relu(d2), options.gf_dim*8, name='g_d3')
  57. d3 = tf.nn.dropout(d3, dropout_rate)
  58. d3 = tf.concat([instance_norm(d3, 'g_bn_d3'), e5], 3)
  59. # d3 is (8 x 8 x self.gf_dim*8*2)
  60. d4 = deconv2d(tf.nn.relu(d3), options.gf_dim*8, name='g_d4')
  61. d4 = tf.concat([instance_norm(d4, 'g_bn_d4'), e4], 3)
  62. # d4 is (16 x 16 x self.gf_dim*8*2)
  63. d5 = deconv2d(tf.nn.relu(d4), options.gf_dim*4, name='g_d5')
  64. d5 = tf.concat([instance_norm(d5, 'g_bn_d5'), e3], 3)
  65. # d5 is (32 x 32 x self.gf_dim*4*2)
  66. d6 = deconv2d(tf.nn.relu(d5), options.gf_dim*2, name='g_d6')
  67. d6 = tf.concat([instance_norm(d6, 'g_bn_d6'), e2], 3)
  68. # d6 is (64 x 64 x self.gf_dim*2*2)
  69. d7 = deconv2d(tf.nn.relu(d6), options.gf_dim, name='g_d7')
  70. d7 = tf.concat([instance_norm(d7, 'g_bn_d7'), e1], 3)
  71. # d7 is (128 x 128 x self.gf_dim*1*2)
  72. d8 = deconv2d(tf.nn.relu(d7), options.output_c_dim, name='g_d8')
  73. # d8 is (256 x 256 x output_c_dim)
  74. return tf.nn.tanh(d8)
  75. def generator_resnet(image, options, reuse=False, name="generator"):
  76. with tf.variable_scope(name):
  77. # image is 256 x 256 x input_c_dim
  78. if reuse:
  79. tf.get_variable_scope().reuse_variables()
  80. else:
  81. assert tf.get_variable_scope().reuse is False
  82. def residule_block(x, dim, ks=3, s=1, name='res'):
  83. p = int((ks - 1) / 2)
  84. y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
  85. y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c1'), name+'_bn1')
  86. y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
  87. y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c2'), name+'_bn2')
  88. return y + x
  89. # Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
  90. # The network with 9 blocks consists of: c7s1-32, d64, d128, R128, R128, R128,
  91. # R128, R128, R128, R128, R128, R128, u64, u32, c7s1-3
  92. c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
  93. c1 = tf.nn.relu(instance_norm(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
  94. c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim*2, 3, 2, name='g_e2_c'), 'g_e2_bn'))
  95. c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim*4, 3, 2, name='g_e3_c'), 'g_e3_bn'))
  96. # define G network with 9 resnet blocks
  97. r1 = residule_block(c3, options.gf_dim*4, name='g_r1')
  98. r2 = residule_block(r1, options.gf_dim*4, name='g_r2')
  99. r3 = residule_block(r2, options.gf_dim*4, name='g_r3')
  100. r4 = residule_block(r3, options.gf_dim*4, name='g_r4')
  101. r5 = residule_block(r4, options.gf_dim*4, name='g_r5')
  102. r6 = residule_block(r5, options.gf_dim*4, name='g_r6')
  103. r7 = residule_block(r6, options.gf_dim*4, name='g_r7')
  104. r8 = residule_block(r7, options.gf_dim*4, name='g_r8')
  105. r9 = residule_block(r8, options.gf_dim*4, name='g_r9')
  106. d1 = deconv2d(r9, options.gf_dim*2, 3, 2, name='g_d1_dc')
  107. d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn'))
  108. d2 = deconv2d(d1, options.gf_dim, 3, 2, name='g_d2_dc')
  109. d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn'))
  110. d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
  111. pred = tf.nn.tanh(conv2d(d2, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c'))
  112. return pred
  113. def abs_criterion(in_, target):
  114. return tf.reduce_mean(tf.abs(in_ - target))
  115. def mae_criterion(in_, target):
  116. return tf.reduce_mean((in_-target)**2)
  117. def sce_criterion(logits, labels):
  118. return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))