You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

LogReg.py 742 B

4 years ago
123456789101112131415161718192021222324
  1. import hetu as ht
  2. from hetu import init
  3. def logreg(x, y_):
  4. '''
  5. Logistic Regression model, for MNIST dataset.
  6. Parameters:
  7. x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
  8. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  9. Return:
  10. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  11. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  12. '''
  13. print("Build logistic regression model...")
  14. weight = init.zeros((784, 10), name='logreg_weight')
  15. bias = init.zeros((10,), name='logreg_bias')
  16. x = ht.matmul_op(x, weight)
  17. y = x + ht.broadcastto_op(bias, x)
  18. loss = ht.softmaxcrossentropy_op(y, y_)
  19. loss = ht.reduce_mean_op(loss, [0])
  20. return loss, y