You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LeNet.py 1.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import hetu as ht
  2. from hetu import init
  3. def conv_pool(x, in_channel, out_channel, name):
  4. weight = init.random_normal(
  5. shape=(out_channel, in_channel, 5, 5), stddev=0.1, name=name+'_weight')
  6. x = ht.conv2d_op(x, weight, padding=2, stride=1)
  7. x = ht.relu_op(x)
  8. x = ht.max_pool2d_op(x, kernel_H=2, kernel_W=2, padding=0, stride=2)
  9. return x
  10. def fc(x, shape, name, with_relu=True):
  11. weight = init.random_normal(shape=shape, stddev=0.1, name=name+'_weight')
  12. bias = init.random_normal(shape=shape[-1:], stddev=0.1, name=name+'_bias')
  13. x = ht.matmul_op(x, weight)
  14. x = x + ht.broadcastto_op(bias, x)
  15. if with_relu:
  16. x = ht.relu_op(x)
  17. return x
  18. def lenet(x, y_):
  19. '''
  20. LeNet model, for MNIST dataset.
  21. Parameters:
  22. x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
  23. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  24. Return:
  25. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  26. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  27. '''
  28. print('Building LeNet model...')
  29. x = ht.array_reshape_op(x, (-1, 1, 28, 28))
  30. x = conv_pool(x, 1, 6, name='lenet_conv1')
  31. x = conv_pool(x, 6, 16, name='lenet_conv2')
  32. x = ht.array_reshape_op(x, (-1, 7*7*16))
  33. x = fc(x, (7*7*16, 120), name='lenet_fc1', with_relu=True)
  34. x = fc(x, (120, 84), name='lenet_fc2', with_relu=True)
  35. y = fc(x, (84, 10), name='lenet_fc3', with_relu=False)
  36. loss = ht.softmaxcrossentropy_op(y, y_)
  37. loss = ht.reduce_mean_op(loss, [0])
  38. return loss, y

分布式深度学习系统

Contributors (1)