You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CNN.py 1.2 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import hetu as ht
  2. from hetu import init
  3. def conv_relu_avg(x, shape):
  4. weight = init.random_normal(shape=shape, stddev=0.1)
  5. x = ht.conv2d_op(x, weight, padding=2, stride=1)
  6. x = ht.relu_op(x)
  7. x = ht.avg_pool2d_op(x, kernel_H=2, kernel_W=2, padding=0, stride=2)
  8. return x
  9. def fc(x, shape):
  10. weight = init.random_normal(shape=shape, stddev=0.1)
  11. bias = init.random_normal(shape=shape[-1:], stddev=0.1)
  12. x = ht.array_reshape_op(x, (-1, shape[0]))
  13. x = ht.matmul_op(x, weight)
  14. y = x + ht.broadcastto_op(bias, x)
  15. return y
  16. def cnn_3_layers(x, y_):
  17. '''
  18. 3-layer-CNN model, for MNIST dataset.
  19. Parameters:
  20. x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
  21. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  22. Return:
  23. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  24. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  25. '''
  26. print('Building 3-layer-CNN model...')
  27. x = ht.array_reshape_op(x, [-1, 1, 28, 28])
  28. x = conv_relu_avg(x, (32, 1, 5, 5))
  29. x = conv_relu_avg(x, (64, 32, 5, 5))
  30. y = fc(x, (7 * 7 * 64, 10))
  31. loss = ht.softmaxcrossentropy_op(y, y_)
  32. loss = ht.reduce_mean_op(loss, [0])
  33. return loss, y