# ops.py

import tensorflow as tf
import numpy as np


def _norm(input, is_train, reuse=True, norm=None):
    assert norm in ['instance', 'batch', None]
    if norm == 'instance':
        with tf.variable_scope('instance_norm', reuse=reuse):
            eps = 1e-5
            mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True)
            normalized = (input - mean) / (tf.sqrt(sigma) + eps)
            out = normalized
            # Apply a learned scale and shift (not mandatory)
            # c = input.get_shape()[-1]
            # shift = tf.get_variable('shift', shape=[c],
            #                         initializer=tf.zeros_initializer())
            # scale = tf.get_variable('scale', shape=[c],
            #                         initializer=tf.random_normal_initializer(1.0, 0.02))
            # out = scale * normalized + shift
    elif norm == 'batch':
        with tf.variable_scope('batch_norm', reuse=reuse):
            # NOTE: is_train is not forwarded here; batch norm is hard-coded
            # to training mode.
            out = tf.contrib.layers.batch_norm(input,
                                               decay=0.99, center=True,
                                               scale=True, is_training=True)
    else:
        out = input
    return out


def _activation(input, activation=None):
    assert activation in ['relu', 'leaky', 'tanh', 'sigmoid', None]
    if activation == 'relu':
        return tf.nn.relu(input)
    elif activation == 'leaky':
        return tf.contrib.keras.layers.LeakyReLU(0.2)(input)
    elif activation == 'tanh':
        return tf.tanh(input)
    elif activation == 'sigmoid':
        return tf.sigmoid(input)
    else:
        return input


def flatten(input):
    return tf.reshape(input, [-1, np.prod(input.get_shape().as_list()[1:])])


def conv2d(input, num_filters, filter_size, stride, reuse=False,
           pad='SAME', dtype=tf.float32, bias=False):
    stride_shape = [1, stride, stride, 1]
    filter_shape = [filter_size, filter_size, input.get_shape()[3], num_filters]
    w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
    if pad == 'REFLECT':
        p = (filter_size - 1) // 2
        x = tf.pad(input, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')
        conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID')
    else:
        assert pad in ['SAME', 'VALID']
        conv = tf.nn.conv2d(input, w, stride_shape, padding=pad)
    if bias:
        b = tf.get_variable('b', [1, 1, 1, num_filters], initializer=tf.constant_initializer(0.0))
        conv = conv + b
    return conv
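

# Usage sketch (illustrative only; the scope and tensor names below are
# assumptions, not part of this module). conv2d creates its 'w'/'b' variables
# directly with tf.get_variable, so the caller must wrap it in a
# tf.variable_scope, as conv_block below does:
#
#   x = tf.placeholder(tf.float32, [None, 64, 64, 3])
#   with tf.variable_scope('example_conv'):
#       y = conv2d(x, num_filters=32, filter_size=3, stride=1,
#                  pad='REFLECT', bias=True)   # -> [None, 64, 64, 32]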


def conv2d_transpose(input, num_filters, filter_size, stride, reuse,
                     pad='SAME', dtype=tf.float32):
    assert pad == 'SAME'
    n, h, w, c = input.get_shape().as_list()
    stride_shape = [1, stride, stride, 1]
    filter_shape = [filter_size, filter_size, num_filters, c]
    output_shape = [n, h * stride, w * stride, num_filters]
    w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
    deconv = tf.nn.conv2d_transpose(input, w, output_shape, stride_shape, pad)
    return deconv
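

# Note (added): conv2d_transpose builds output_shape from the *static* shape of
# its input, so the batch size must be known at graph-construction time, e.g.
#
#   x = tf.placeholder(tf.float32, [4, 32, 32, 64])   # fixed batch of 4
#   with tf.variable_scope('example_deconv'):
#       y = conv2d_transpose(x, num_filters=32, filter_size=3, stride=2,
#                            reuse=False)             # -> [4, 64, 64, 32]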


def mlp(input, out_dim, name, is_train, reuse, norm=None, activation=None,
        dtype=tf.float32, bias=True):
    with tf.variable_scope(name, reuse=reuse):
        _, n = input.get_shape()
        w = tf.get_variable('w', [n, out_dim], dtype, tf.random_normal_initializer(0.0, 0.02))
        out = tf.matmul(input, w)
        if bias:
            b = tf.get_variable('b', [out_dim], initializer=tf.constant_initializer(0.0))
            out = out + b
        out = _activation(out, activation)
        out = _norm(out, is_train, reuse, norm)
        return out


def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
               activation, pad='SAME', bias=False):
    with tf.variable_scope(name, reuse=reuse):
        out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias=bias)
        out = _norm(out, is_train, reuse, norm)
        out = _activation(out, activation)
        return out


def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT',
             bias=False):
    with tf.variable_scope(name, reuse=reuse):
        with tf.variable_scope('res1', reuse=reuse):
            out = conv2d(input, num_filters, 3, 1, reuse, pad, bias=bias)
            out = _norm(out, is_train, reuse, norm)
            out = tf.nn.relu(out)
        with tf.variable_scope('res2', reuse=reuse):
            out = conv2d(out, num_filters, 3, 1, reuse, pad, bias=bias)
            out = _norm(out, is_train, reuse, norm)
        with tf.variable_scope('shortcut', reuse=reuse):
            shortcut = conv2d(input, num_filters, 1, 1, reuse, pad, bias=bias)
        return tf.nn.relu(shortcut + out)


def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
                 norm, activation):
    with tf.variable_scope(name, reuse=reuse):
        out = conv2d_transpose(input, num_filters, k_size, stride, reuse)
        out = _norm(out, is_train, reuse, norm)
        out = _activation(out, activation)
        return out
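

# ---------------------------------------------------------------------------
# Minimal smoke test -- an illustrative sketch, not part of the original
# module. It assumes the TF1-style graph/session API used above; the layer
# names and shapes are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    images = tf.placeholder(tf.float32, [4, 64, 64, 3], name='images')
    out = conv_block(images, 32, 'c1', 7, 1, is_train=True, reuse=False,
                     norm='instance', activation='relu', pad='REFLECT')
    out = conv_block(out, 64, 'c2', 3, 2, is_train=True, reuse=False,
                     norm='instance', activation='relu')
    out = residual(out, 64, 'res', is_train=True, reuse=False, norm='instance')
    out = deconv_block(out, 32, 'd1', 3, 2, is_train=True, reuse=False,
                       norm='instance', activation='relu')
    logits = mlp(flatten(out), 10, 'fc', is_train=True, reuse=False)
    print(logits.get_shape())  # expected: (4, 10)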