You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VGG.py 3.6 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import hetu as ht
  2. from hetu import init
  3. def conv_bn_relu(x, in_channel, out_channel, name):
  4. weight = init.random_normal(shape=(out_channel, in_channel, 3, 3),
  5. stddev=0.1, name=name+'_weight')
  6. bn_scale = init.random_normal(shape=(1, out_channel, 1, 1),
  7. stddev=0.1, name=name+'_bn_scale')
  8. bn_bias = init.random_normal(shape=(1, out_channel, 1, 1),
  9. stddev=0.1, name=name+'_bn_bias')
  10. x = ht.conv2d_op(x, weight, padding=1, stride=1)
  11. x = ht.batch_normalization_op(x, bn_scale, bn_bias)
  12. act = ht.relu_op(x)
  13. return act
  14. def vgg_2block(x, in_channel, out_channel, name):
  15. x = conv_bn_relu(x, in_channel, out_channel, name=name+'_layer1')
  16. x = conv_bn_relu(x, out_channel, out_channel, name=name+'_layer2')
  17. x = ht.max_pool2d_op(x, kernel_H=2, kernel_W=2, padding=0, stride=2)
  18. return x
  19. def vgg_3block(x, in_channel, out_channel, name):
  20. x = conv_bn_relu(x, in_channel, out_channel, name=name+'_layer1')
  21. x = conv_bn_relu(x, out_channel, out_channel, name=name+'_layer2')
  22. x = conv_bn_relu(x, out_channel, out_channel, name=name+'_layer3')
  23. x = ht.max_pool2d_op(x, kernel_H=2, kernel_W=2, padding=0, stride=2)
  24. return x
  25. def vgg_4block(x, in_channel, out_channel, name):
  26. x = conv_bn_relu(x, in_channel, out_channel, name=name+'_layer1')
  27. x = conv_bn_relu(x, out_channel, out_channel, name=name+'_layer2')
  28. x = conv_bn_relu(x, out_channel, out_channel, name=name+'_layer3')
  29. x = conv_bn_relu(x, out_channel, out_channel, name=name+'_layer4')
  30. x = ht.max_pool2d_op(x, kernel_H=2, kernel_W=2, padding=0, stride=2)
  31. return x
  32. def vgg_fc(x, in_feat, out_feat, name):
  33. weight = init.random_normal(shape=(in_feat, out_feat),
  34. stddev=0.1, name=name+'_weight')
  35. bias = init.random_normal(shape=(out_feat,),
  36. stddev=0.1, name=name+'_bias')
  37. x = ht.matmul_op(x, weight)
  38. x = x + ht.broadcastto_op(bias, x)
  39. return x
  40. def vgg(x, y_, num_layers, num_class=10):
  41. '''
  42. VGG model, for CIFAR10/CIFAR100 dataset.
  43. Parameters:
  44. x: Variable(hetu.gpu_ops.Node.Node), shape (N, C, H, W)
  45. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  46. num_layers: 16 or 19
  47. Return:
  48. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  49. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  50. '''
  51. if num_layers == 16:
  52. print('Building VGG-16 model...')
  53. x = vgg_2block(x, 3, 64, 'vgg_block1')
  54. x = vgg_2block(x, 64, 128, 'vgg_block2')
  55. x = vgg_3block(x, 128, 256, 'vgg_block3')
  56. x = vgg_3block(x, 256, 512, 'vgg_block4')
  57. x = vgg_3block(x, 512, 512, 'vgg_block5')
  58. elif num_layers == 19:
  59. print('Building VGG-19 model...')
  60. x = vgg_2block(x, 3, 64, 'vgg_block1')
  61. x = vgg_2block(x, 64, 128, 'vgg_block2')
  62. x = vgg_4block(x, 128, 256, 'vgg_block3')
  63. x = vgg_4block(x, 256, 512, 'vgg_block4')
  64. x = vgg_4block(x, 512, 512, 'vgg_block5')
  65. else:
  66. assert False, 'VGG model should have 16 or 19 layers!'
  67. x = ht.array_reshape_op(x, (-1, 512))
  68. x = vgg_fc(x, 512, 4096, 'vgg_fc1')
  69. x = vgg_fc(x, 4096, 4096, 'vgg_fc2')
  70. y = vgg_fc(x, 4096, num_class, 'vgg_fc3')
  71. loss = ht.softmaxcrossentropy_op(y, y_)
  72. loss = ht.reduce_mean_op(loss, [0])
  73. return loss, y
  74. def vgg16(x, y_, num_class=10):
  75. return vgg(x, y_, 16, num_class)
  76. def vgg19(x, y_, num_class=10):
  77. return vgg(x, y_, 19, num_class)