# ops.py
  1. import tensorflow as tf
  2. def _norm(input, is_train, reuse=True, norm=None):
  3. assert norm in ['instance', 'batch', None]
  4. if norm == 'instance':
  5. with tf.variable_scope('instance_norm', reuse=reuse):
  6. eps = 1e-5
  7. mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True)
  8. normalized = (input - mean) / (tf.sqrt(sigma) + eps)
  9. out = normalized
  10. # Apply momentum (not mendatory)
  11. #c = input.get_shape()[-1]
  12. #shift = tf.get_variable('shift', shape=[c],
  13. # initializer=tf.zeros_initializer())
  14. #scale = tf.get_variable('scale', shape=[c],
  15. # initializer=tf.random_normal_initializer(1.0, 0.02))
  16. #out = scale * normalized + shift
  17. elif norm == 'batch':
  18. with tf.variable_scope('batch_norm', reuse=reuse):
  19. out = tf.contrib.layers.batch_norm(input,
  20. decay=0.99, center=True,
  21. scale=True, is_training=True)
  22. else:
  23. out = input
  24. return out
  25. def _activation(input, activation=None):
  26. assert activation in ['relu', 'leaky', 'tanh', 'sigmoid', None]
  27. if activation == 'relu':
  28. return tf.nn.relu(input)
  29. elif activation == 'leaky':
  30. return tf.contrib.keras.layers.LeakyReLU(0.2)(input)
  31. elif activation == 'tanh':
  32. return tf.tanh(input)
  33. elif activation == 'sigmoid':
  34. return tf.sigmoid(input)
  35. else:
  36. return input
  37. def conv2d(input, num_filters, filter_size, stride, reuse=False,
  38. pad='SAME', dtype=tf.float32, bias=False):
  39. stride_shape = [1, stride, stride, 1]
  40. filter_shape = [filter_size, filter_size, input.get_shape()[3], num_filters]
  41. w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
  42. if pad == 'REFLECT':
  43. p = (filter_size - 1) // 2
  44. x = tf.pad(input, [[0,0],[p,p],[p,p],[0,0]], 'REFLECT')
  45. conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID')
  46. else:
  47. assert pad in ['SAME', 'VALID']
  48. conv = tf.nn.conv2d(input, w, stride_shape, padding=pad)
  49. if bias:
  50. b = tf.get_variable('b', [1,1,1,num_filters], initializer=tf.constant_initializer(0.0))
  51. conv = conv + b
  52. return conv
  53. def conv2d_transpose(input, num_filters, filter_size, stride, reuse,
  54. pad='SAME', dtype=tf.float32):
  55. assert pad == 'SAME'
  56. n, h, w, c = input.get_shape().as_list()
  57. stride_shape = [1, stride, stride, 1]
  58. filter_shape = [filter_size, filter_size, num_filters, c]
  59. output_shape = [n, h * stride, w * stride, num_filters]
  60. w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
  61. deconv = tf.nn.conv2d_transpose(input, w, output_shape, stride_shape, pad)
  62. return deconv
  63. def mlp(input, out_dim, name, is_train, reuse, norm=None, activation=None,
  64. dtype=tf.float32, bias=True):
  65. with tf.variable_scope(name, reuse=reuse):
  66. _, n = input.get_shape()
  67. w = tf.get_variable('w', [n, out_dim], dtype, tf.random_normal_initializer(0.0, 0.02))
  68. out = tf.matmul(input, w)
  69. if bias:
  70. b = tf.get_variable('b', [out_dim], initializer=tf.constant_initializer(0.0))
  71. out = out + b
  72. out = _activation(out, activation)
  73. out = _norm(out, is_train, reuse, norm)
  74. return out
  75. def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
  76. activation, pad='SAME', bias=False):
  77. with tf.variable_scope(name, reuse=reuse):
  78. out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias=bias)
  79. out = _norm(out, is_train, reuse, norm)
  80. out = _activation(out, activation)
  81. return out
  82. def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT'):
  83. with tf.variable_scope(name, reuse=reuse):
  84. with tf.variable_scope('res1', reuse=reuse):
  85. out = conv2d(input, num_filters, 3, 1, reuse, pad)
  86. out = _norm(out, is_train, reuse, norm)
  87. out = tf.nn.relu(out)
  88. with tf.variable_scope('res2', reuse=reuse):
  89. out = conv2d(out, num_filters, 3, 1, reuse, pad)
  90. out = _norm(out, is_train, reuse, norm)
  91. return tf.nn.relu(input + out)
  92. def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
  93. norm, activation):
  94. with tf.variable_scope(name, reuse=reuse):
  95. out = conv2d_transpose(input, num_filters, k_size, stride, reuse)
  96. out = _norm(out, is_train, reuse, norm)
  97. out = _activation(out, activation)
  98. return out