You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MLP.py 1.0 kB

4 years ago
123456789101112131415161718192021222324252627282930313233
  1. import hetu as ht
  2. from hetu import init
  def fc(x, shape, name, with_relu=True):
      '''
      Fully connected layer: out = x @ W + b, optionally followed by ReLU.
      Parameters:
          x: input node; mlp() passes shapes like (N, shape[0]).
          shape: weight shape; as used by mlp() it is (in_features, out_features).
              The bias takes only the last entry, shape[-1:].
          name: prefix for the created parameters ('<name>_weight', '<name>_bias').
          with_relu: when truthy, apply ht.relu_op to the affine output.
      Return:
          The output node of the layer.
      '''
      # Both parameters are drawn from a normal init with stddev 0.1;
      # the weight is created before the bias.
      w = init.random_normal(shape=shape, stddev=0.1, name=name + '_weight')
      b = init.random_normal(shape=shape[-1:], stddev=0.1, name=name + '_bias')
      out = ht.matmul_op(x, w)
      # broadcastto_op presumably expands the 1-D bias to out's shape before adding.
      out = out + ht.broadcastto_op(b, out)
      return ht.relu_op(out) if with_relu else out
  def mlp(x, y_, in_dim=784, hidden_dims=(256, 256), num_classes=10):
      '''
      MLP model, for MNIST dataset.

      Builds len(hidden_dims) fully connected ReLU layers followed by a linear
      output layer, then a batch-averaged softmax cross-entropy loss. The
      default arguments reproduce the original 784-256-256-10 network,
      including the parameter names 'mlp_fc1' .. 'mlp_fc3'.
      Parameters:
          x: Variable(hetu.gpu_ops.Node.Node), shape (N, in_dim)
          y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
          in_dim: input feature size (default 784, i.e. a flattened 28x28 image).
          hidden_dims: sizes of the hidden layers, in order; may be empty,
              in which case a single linear layer maps in_dim -> num_classes.
          num_classes: number of output logits.
      Return:
          loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
          y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
      '''
      print("Building MLP model...")
      # Feature sizes along the chain: in_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
      dims = [in_dim] + list(hidden_dims)
      for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:]), start=1):
          x = fc(x, (d_in, d_out), 'mlp_fc%d' % i, with_relu=True)
      # No ReLU on the output layer: these are the logits fed to the loss op,
      # which presumably applies softmax itself.
      y = fc(x, (dims[-1], num_classes), 'mlp_fc%d' % len(dims), with_relu=False)
      loss = ht.softmaxcrossentropy_op(y, y_)
      # Average over axis 0 (the batch axis).
      loss = ht.reduce_mean_op(loss, [0])
      return loss, y