You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_ResNet.py 4.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import numpy as np
  2. import tensorflow as tf
  def tf_conv2d(x, in_channel, out_channel, stride=1):
      """Apply a 3x3 convolution with 'SAME' padding and new random weights.

      A new ``tf.Variable`` kernel is created on every call, initialised
      from a normal distribution with standard deviation 0.1.

      Parameters:
          x: input tensor passed straight to ``tf.nn.conv2d``.
          in_channel: number of input channels of the kernel.
          out_channel: number of output channels of the kernel.
          stride: stride used on both middle (spatial) dimensions.
      Return:
          The result of ``tf.nn.conv2d``.
      """
      # Sample in (out, in, 3, 3) order, then permute the axes into the
      # (3, 3, in, out) layout that is handed to tf.nn.conv2d.
      kernel_init = np.random.normal(
          scale=0.1, size=(out_channel, in_channel, 3, 3))
      kernel_init = kernel_init.transpose([2, 3, 1, 0]).astype(np.float32)
      kernel = tf.Variable(kernel_init)
      window_strides = [1, stride, stride, 1]
      return tf.nn.conv2d(x, kernel, strides=window_strides, padding='SAME')
  def tf_batch_norm_with_relu(x, hidden):
      """Normalise ``x`` with its own batch statistics, then apply ReLU.

      The mean and variance are computed from ``x`` itself over every axis
      except the last one; no running averages are kept. New scale and
      offset variables of length ``hidden`` are created on each call.

      Parameters:
          x: input tensor; its last axis is the one that is not reduced.
          hidden: length of the scale and offset vectors.
      Return:
          The normalised tensor after ``tf.nn.relu``.
      """
      def _random_vector():
          # Normal(0, 0.1) initial value of shape (hidden,), as float32.
          init = np.random.normal(scale=0.1, size=(hidden,))
          return tf.Variable(init.astype(np.float32))

      # Creation order (scale first, then offset) matches the random draws
      # of the original implementation.
      gamma = _random_vector()
      beta = _random_vector()

      # Reduce over all axes but the last.
      reduce_axes = list(range(len(x.shape) - 1))
      batch_mean, batch_var = tf.nn.moments(x, reduce_axes)
      normalised = tf.nn.batch_normalization(
          x,
          mean=batch_mean,
          variance=batch_var,
          scale=gamma,
          offset=beta,
          variance_epsilon=1e-2,
      )
      return tf.nn.relu(normalised)
  def tf_resnet_block(x, in_channel, num_blocks, is_first=False):
      """Build one ResNet stage made of ``num_blocks`` residual units.

      With ``is_first=True`` the stage keeps the channel count and the
      spatial size, and its first unit skips the leading batch-norm/ReLU.
      Otherwise the first unit doubles the channel count and uses a
      stride-2 convolution; the shortcut is brought to the same shape by
      2x2 average pooling and zero-padding of the channel axis.

      Parameters:
          x: input tensor with ``in_channel`` channels on the last axis.
          in_channel: number of channels of ``x``.
          num_blocks: total number of residual units in the stage.
          is_first: True for the stage right after the stem.
      Return:
          Output tensor of the stage, with ``in_channel`` channels when
          ``is_first`` is True and ``2 * in_channel`` channels otherwise.
      """
      shortcut = x
      if is_first:
          out_channel = in_channel
          x = tf_conv2d(x, in_channel, out_channel, stride=1)
          x = tf_batch_norm_with_relu(x, out_channel)
          x = tf_conv2d(x, out_channel, out_channel, stride=1)
      else:
          out_channel = 2 * in_channel
          x = tf_batch_norm_with_relu(x, in_channel)
          x = tf_conv2d(x, in_channel, out_channel, stride=2)
          x = tf_batch_norm_with_relu(x, out_channel)
          x = tf_conv2d(x, out_channel, out_channel, stride=1)
          # NOTE(review): 'VALID' pooling matches the stride-2 'SAME'
          # convolution only for even spatial sizes -- confirm that inputs
          # always have even height and width.
          shortcut = tf.nn.avg_pool(
              shortcut,
              ksize=[1, 2, 2, 1],
              strides=[1, 2, 2, 1],
              padding='VALID',
          )
          # Zero-pad the channel axis up to out_channel. The trailing pad
          # takes the remainder so that odd widths also add up to
          # 2 * in_channel; for even widths both sides are in_channel // 2.
          pad_before = in_channel // 2
          pad_after = in_channel - pad_before
          shortcut = tf.pad(
              shortcut, [[0, 0], [0, 0], [0, 0], [pad_before, pad_after]])
      x = x + shortcut

      # The remaining units are identical: BN-ReLU-conv twice, plus identity.
      for _ in range(1, num_blocks):
          shortcut = x
          x = tf_batch_norm_with_relu(x, out_channel)
          x = tf_conv2d(x, out_channel, out_channel, stride=1)
          x = tf_batch_norm_with_relu(x, out_channel)
          x = tf_conv2d(x, out_channel, out_channel, stride=1)
          x = x + shortcut
      return x
  def tf_fc(x, shape):
      """Apply a fully connected layer ``x @ W + b`` with new random weights.

      Parameters:
          x: input tensor passed to ``tf.matmul``.
          shape: shape of the weight matrix; its last entry is also the
              length of the bias vector.
      Return:
          ``tf.matmul(x, weight) + bias``.
      """
      # Weight is drawn before bias, keeping the original draw order.
      weight_init = np.random.normal(scale=0.1, size=shape)
      weight = tf.Variable(weight_init.astype(np.float32))
      bias_init = np.random.normal(scale=0.1, size=shape[-1:])
      bias = tf.Variable(bias_init.astype(np.float32))
      projected = tf.matmul(x, weight)
      return projected + bias
  def tf_resnet(x, y_, num_layers, num_class=10):
      '''
      ResNet model in TensorFlow, for CIFAR10 dataset.

      Parameters:
          x: Variable(tensorflow.python.framework.ops.Tensor), shape (N, H, W, C)
          y_: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_class)
          num_layers: 18 or 34
          num_class: number of output classes (width of the final layer)
      Return:
          loss: Variable(tensorflow.python.framework.ops.Tensor), scalar
              (mean softmax cross-entropy over the batch)
          y: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_class)
      Raises:
          ValueError: if num_layers is neither 18 nor 34.
      '''
      # Number of residual units in each of the four stages.
      stage_depths = {18: (2, 2, 2, 2), 34: (3, 4, 6, 3)}
      # Validate before any variable is created; an explicit raise is used
      # because assert statements are stripped under "python -O".
      if num_layers not in stage_depths:
          raise ValueError("Number of layers should be 18 or 34 !")
      depths = stage_depths[num_layers]

      print("Number of Class: {}".format(num_class))
      base_size = 16

      # Stem: 3 input channels -> base_size channels.
      x = tf_conv2d(x, 3, base_size, stride=1)
      x = tf_batch_norm_with_relu(x, base_size)

      print("Building ResNet-{} model in tensorflow...".format(num_layers))
      # The first stage keeps the width; each later stage doubles it.
      x = tf_resnet_block(x, base_size, num_blocks=depths[0], is_first=True)
      width = base_size
      for depth in depths[1:]:
          x = tf_resnet_block(x, width, num_blocks=depth)
          width *= 2
      # Here width == 8 * base_size.

      x = tf_batch_norm_with_relu(x, width)
      x = tf.transpose(x, [0, 3, 1, 2])
      # 16 * width == 128 * base_size features per example.
      # NOTE(review): this presumably corresponds to a 4x4 feature map, i.e.
      # 32x32 inputs halved by the three stride-2 stages -- confirm.
      flat_dim = 16 * width
      x = tf.reshape(x, [-1, flat_dim])
      y = tf_fc(x, (flat_dim, num_class))

      loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
      loss = tf.reduce_mean(loss)
      return loss, y
  def tf_resnet18(x, y_, num_class=10):
      """Build the 18-layer variant; returns ``(loss, y)`` from ``tf_resnet``."""
      return tf_resnet(x, y_, num_layers=18, num_class=num_class)
  def tf_resnet34(x, y_, num_class=10):
      """Build the 34-layer variant; returns ``(loss, y)`` from ``tf_resnet``."""
      return tf_resnet(x, y_, num_layers=34, num_class=num_class)

分布式深度学习系统

Contributors (1)