ops.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import tensorflow as tf
  2. def _norm(input, is_train, reuse=True, norm=None):
  3. assert norm in ['instance', 'batch', None]
  4. if norm == 'instance':
  5. with tf.variable_scope('instance_norm', reuse=reuse):
  6. eps = 1e-5
  7. mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True)
  8. normalized = (input - mean) / (tf.sqrt(sigma) + eps)
  9. out = normalized
  10. # Apply momentum (not mendatory)
  11. #c = input.get_shape()[-1]
  12. #shift = tf.get_variable('shift', shape=[c],
  13. # initializer=tf.zeros_initializer())
  14. #scale = tf.get_variable('scale', shape=[c],
  15. # initializer=tf.random_normal_initializer(1.0, 0.02))
  16. #out = scale * normalized + shift
  17. elif norm == 'batch':
  18. with tf.variable_scope('batch_norm', reuse=reuse):
  19. out = tf.contrib.layers.batch_norm(input,
  20. decay=0.99, center=True,
  21. scale=True, is_training=is_train,
  22. updates_collections=None)
  23. else:
  24. out = input
  25. return out
  26. def _activation(input, activation=None):
  27. assert activation in ['relu', 'leaky', 'tanh', 'sigmoid', None]
  28. if activation == 'relu':
  29. return tf.nn.relu(input)
  30. elif activation == 'leaky':
  31. return tf.contrib.keras.layers.LeakyReLU(0.2)(input)
  32. elif activation == 'tanh':
  33. return tf.tanh(input)
  34. elif activation == 'sigmoid':
  35. return tf.sigmoid(input)
  36. else:
  37. return input
  38. def conv2d(input, num_filters, filter_size, stride, reuse=False,
  39. pad='SAME', dtype=tf.float32, bias=False):
  40. stride_shape = [1, stride, stride, 1]
  41. filter_shape = [filter_size, filter_size, input.get_shape()[3], num_filters]
  42. w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
  43. if pad == 'REFLECT':
  44. p = (filter_size - 1) // 2
  45. x = tf.pad(input, [[0,0],[p,p],[p,p],[0,0]], 'REFLECT')
  46. conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID')
  47. else:
  48. assert pad in ['SAME', 'VALID']
  49. conv = tf.nn.conv2d(input, w, stride_shape, padding=pad)
  50. if bias:
  51. b = tf.get_variable('b', [1,1,1,num_filters], initializer=tf.constant_initializer(0.0))
  52. conv = conv + b
  53. return conv
  54. def conv2d_transpose(input, num_filters, filter_size, stride, reuse,
  55. pad='SAME', dtype=tf.float32):
  56. assert pad == 'SAME'
  57. n, h, w, c = input.get_shape().as_list()
  58. stride_shape = [1, stride, stride, 1]
  59. filter_shape = [filter_size, filter_size, num_filters, c]
  60. output_shape = [n, h * stride, w * stride, num_filters]
  61. w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
  62. deconv = tf.nn.conv2d_transpose(input, w, output_shape, stride_shape, pad)
  63. return deconv
  64. def mlp(input, out_dim, name, is_train, reuse, norm=None, activation=None,
  65. dtype=tf.float32, bias=True):
  66. with tf.variable_scope(name, reuse=reuse):
  67. _, n = input.get_shape()
  68. w = tf.get_variable('w', [n, out_dim], dtype, tf.random_normal_initializer(0.0, 0.02))
  69. out = tf.matmul(input, w)
  70. if bias:
  71. b = tf.get_variable('b', [out_dim], initializer=tf.constant_initializer(0.0))
  72. out = out + b
  73. out = _activation(out, activation)
  74. out = _norm(out, is_train, reuse, norm)
  75. return out
  76. def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
  77. activation, pad='SAME', bias=False):
  78. with tf.variable_scope(name, reuse=reuse):
  79. out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias=bias)
  80. out = _norm(out, is_train, reuse, norm)
  81. out = _activation(out, activation)
  82. return out
  83. def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT'):
  84. with tf.variable_scope(name, reuse=reuse):
  85. with tf.variable_scope('res1', reuse=reuse):
  86. out = conv2d(input, num_filters, 3, 1, reuse, pad)
  87. out = _norm(out, is_train, reuse, norm)
  88. out = tf.nn.relu(out)
  89. with tf.variable_scope('res2', reuse=reuse):
  90. out = conv2d(out, num_filters, 3, 1, reuse, pad)
  91. out = _norm(out, is_train, reuse, norm)
  92. return tf.nn.relu(input + out)
  93. def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
  94. norm, activation):
  95. with tf.variable_scope(name, reuse=reuse):
  96. out = conv2d_transpose(input, num_filters, k_size, stride, reuse)
  97. out = _norm(out, is_train, reuse, norm)
  98. out = _activation(out, activation)
  99. return out