You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ResNet.py 4.9 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import hetu as ht
  2. from hetu import init
  3. def conv2d(x, in_channel, out_channel, stride=1, padding=1, name=''):
  4. weight = init.random_normal(
  5. shape=(out_channel, in_channel, 3, 3), stddev=0.1, name=name+'_weight')
  6. x = ht.conv2d_op(x, weight, stride=stride, padding=padding)
  7. return x
  8. def batch_norm_with_relu(x, hidden, name):
  9. scale = init.random_normal(
  10. shape=(1, hidden, 1, 1), stddev=0.1, name=name+'_scale')
  11. bias = init.random_normal(shape=(1, hidden, 1, 1),
  12. stddev=0.1, name=name+'_bias')
  13. x = ht.batch_normalization_op(x, scale, bias)
  14. x = ht.relu_op(x)
  15. return x
  16. def resnet_block(x, in_channel, num_blocks, is_first=False, name=''):
  17. if is_first:
  18. out_channel = in_channel
  19. identity = x
  20. x = conv2d(x, in_channel, out_channel, stride=1,
  21. padding=1, name=name+'_conv1')
  22. x = batch_norm_with_relu(x, out_channel, name+'_bn1')
  23. x = conv2d(x, out_channel, out_channel, stride=1,
  24. padding=1, name=name+'_conv2')
  25. x = x + identity
  26. else:
  27. out_channel = 2 * in_channel
  28. identity = x
  29. x = batch_norm_with_relu(x, in_channel, name+'_bn0')
  30. x = ht.pad_op(x, [[0, 0], [0, 0], [0, 1], [0, 1]])
  31. x = conv2d(x, in_channel, out_channel, stride=2,
  32. padding=0, name=name+'_conv1')
  33. x = batch_norm_with_relu(x, out_channel, name+'_bn1')
  34. x = conv2d(x, out_channel, out_channel, stride=1,
  35. padding=1, name=name+'_conv2')
  36. identity = ht.avg_pool2d_op(
  37. identity, kernel_H=2, kernel_W=2, padding=0, stride=2)
  38. identity = ht.pad_op(
  39. identity, [[0, 0], [in_channel // 2, in_channel // 2], [0, 0], [0, 0]])
  40. x = x + identity
  41. for i in range(1, num_blocks):
  42. identity = x
  43. x = batch_norm_with_relu(x, out_channel, name+'_bn%d' % (2 * i))
  44. x = conv2d(x, out_channel, out_channel, stride=1,
  45. padding=1, name=name+'_conv%d' % (2 * i + 1))
  46. x = batch_norm_with_relu(x, out_channel, name+'_bn%d' % (2 * i + 1))
  47. x = conv2d(x, out_channel, out_channel, stride=1,
  48. padding=1, name=name+'_conv%d' % (2 * i + 2))
  49. x = x + identity
  50. return x
  51. def fc(x, shape, name):
  52. weight = init.random_normal(shape=shape, stddev=0.1, name=name+'_weight')
  53. bias = init.random_normal(shape=shape[-1:], stddev=0.1, name=name+'_bias')
  54. x = ht.matmul_op(x, weight)
  55. x = x + ht.broadcastto_op(bias, x)
  56. return x
  57. def resnet(x, y_, num_layers=18, num_class=10):
  58. '''
  59. ResNet model, for CIFAR10 dataset.
  60. Parameters:
  61. x: Variable(hetu.gpu_ops.Node.Node), shape (N, C, H, W)
  62. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  63. num_layers: 18 or 34
  64. Return:
  65. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  66. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  67. '''
  68. base_size = 16
  69. x = conv2d(x, 3, base_size, stride=1, padding=1,
  70. name='resnet_initial_conv')
  71. x = batch_norm_with_relu(x, base_size, 'resnet_initial_bn')
  72. if num_layers == 18:
  73. print("Building ResNet-18 model...")
  74. x = resnet_block(x, base_size, num_blocks=2,
  75. is_first=True, name='resnet_block1')
  76. x = resnet_block(x, base_size, num_blocks=2,
  77. is_first=False, name='resnet_block2')
  78. x = resnet_block(x, 2 * base_size, num_blocks=2,
  79. is_first=False, name='resnet_block3')
  80. x = resnet_block(x, 4 * base_size, num_blocks=2,
  81. is_first=False, name='resnet_block4')
  82. elif num_layers == 34:
  83. print("Building ResNet-34 model...")
  84. x = resnet_block(x, base_size, num_blocks=3,
  85. is_first=True, name='resnet_block1')
  86. x = resnet_block(x, base_size, num_blocks=4,
  87. is_first=False, name='resnet_block2')
  88. x = resnet_block(x, 2 * base_size, num_blocks=6,
  89. is_first=False, name='resnet_block3')
  90. x = resnet_block(x, 4 * base_size, num_blocks=3,
  91. is_first=False, name='resnet_block4')
  92. else:
  93. assert False, "Number of layers should be 18 or 34 !"
  94. x = batch_norm_with_relu(x, 8 * base_size, 'resnet_final_bn')
  95. x = ht.array_reshape_op(x, (-1, 128 * base_size))
  96. y = fc(x, (128 * base_size, num_class), name='resnet_final_fc')
  97. # here we don't use cudnn for softmax crossentropy to avoid overflows
  98. loss = ht.softmaxcrossentropy_op(y, y_, use_cudnn=False)
  99. loss = ht.reduce_mean_op(loss, [0])
  100. return loss, y
  101. def resnet18(x, y_, num_class=10):
  102. return resnet(x, y_, 18, num_class)
  103. def resnet34(x, y_, num_class=10):
  104. return resnet(x, y_, 34, num_class)

分布式深度学习系统

Contributors (1)