You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

AlexNet.py 2.4 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import hetu as ht
  2. from hetu import init
  3. def conv_bn_relu_pool(x, in_channel, out_channel, name, with_relu=True, with_pool=False):
  4. weight = init.random_normal(
  5. shape=(out_channel, in_channel, 3, 3), stddev=0.1, name=name+'_weight')
  6. bn_scale = init.random_normal(
  7. shape=(1, out_channel, 1, 1), stddev=0.1, name=name+'_bn_scale')
  8. bn_bias = init.random_normal(
  9. shape=(1, out_channel, 1, 1), stddev=0.1, name=name+'_bn_bias')
  10. x = ht.conv2d_op(x, weight, stride=1, padding=1)
  11. x = ht.batch_normalization_op(x, bn_scale, bn_bias)
  12. if with_relu:
  13. x = ht.relu_op(x)
  14. if with_pool:
  15. x = ht.max_pool2d_op(x, kernel_H=2, kernel_W=2, stride=2, padding=0)
  16. return x
  17. def fc(x, shape, name, with_relu=True):
  18. weight = init.random_normal(shape=shape, stddev=0.1, name=name+'_weight')
  19. bias = init.random_normal(shape=shape[-1:], stddev=0.1, name=name+'_bias')
  20. x = ht.matmul_op(x, weight)
  21. x = x + ht.broadcastto_op(bias, x)
  22. if with_relu:
  23. x = ht.relu_op(x)
  24. return x
  25. def alexnet(x, y_):
  26. '''
  27. AlexNet model, for MNIST dataset.
  28. Parameters:
  29. x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
  30. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  31. Return:
  32. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  33. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  34. '''
  35. print('Building AlexNet model...')
  36. x = ht.array_reshape_op(x, [-1, 1, 28, 28])
  37. x = conv_bn_relu_pool(x, 1, 32, 'alexnet_conv1',
  38. with_relu=True, with_pool=True)
  39. x = conv_bn_relu_pool(x, 32, 64, 'alexnet_conv2',
  40. with_relu=True, with_pool=True)
  41. x = conv_bn_relu_pool(x, 64, 128, 'alexnet_conv3',
  42. with_relu=True, with_pool=False)
  43. x = conv_bn_relu_pool(x, 128, 256, 'alexnet_conv4',
  44. with_relu=True, with_pool=False)
  45. x = conv_bn_relu_pool(x, 256, 256, 'alexnet_conv5',
  46. with_relu=False, with_pool=True)
  47. x = ht.array_reshape_op(x, (-1, 256*3*3))
  48. x = fc(x, (256*3*3, 1024), name='alexnet_fc1', with_relu=True)
  49. x = fc(x, (1024, 512), name='alexnet_fc2', with_relu=True)
  50. y = fc(x, (512, 10), name='alexnet_fc3', with_relu=False)
  51. loss = ht.softmaxcrossentropy_op(y, y_)
  52. loss = ht.reduce_mean_op(loss, [0])
  53. return loss, y