
tf_LogReg.py

4 years ago
```python
import numpy as np
import tensorflow as tf


def tf_logreg(x, y_):
    '''
    Logistic regression model in TensorFlow, for the MNIST dataset.

    Parameters:
        x: tf.Tensor (float32), shape (N, 784), flattened 28x28 input images
        y_: tf.Tensor (float32), shape (N, 10), one-hot labels
    Returns:
        loss: tf.Tensor, shape (), mean softmax cross-entropy over the batch
        y: tf.Tensor, shape (N, 10), unnormalized logits
    '''
    print("Build logistic regression model in tensorflow...")
    # Zero-initialized parameters: 784 input features -> 10 classes
    weight = tf.Variable(np.zeros(shape=(784, 10)).astype(np.float32))
    bias = tf.Variable(np.zeros(shape=(10, )).astype(np.float32))
    # Logits: y = x W + b
    y = tf.matmul(x, weight) + bias
    # Per-example softmax cross-entropy, averaged into a scalar loss
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
    loss = tf.reduce_mean(loss)
    return loss, y
```
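For readers without TensorFlow at hand, the loss computed by the model's last three lines can be sketched in plain Python. This is a minimal, framework-free illustration, not the repository's code: the helpers `softmax_xent` and `logreg_loss` are hypothetical names, and the nested-list data layout is an assumption chosen to keep the example dependency-free. It subtracts each row's maximum before exponentiating (the usual log-sum-exp trick), which is one standard way to evaluate softmax cross-entropy stably.

```python
import math


def softmax_xent(logits, labels):
    """Per-example softmax cross-entropy computed from raw logits.

    Hypothetical pure-Python stand-in for the TensorFlow op.
    logits, labels: lists of equal-length rows, one row per example.
    Returns a list with one loss value per example.
    """
    losses = []
    for z, t in zip(logits, labels):
        m = max(z)  # subtract the row max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        # -sum_j t_j * log softmax(z)_j, with log softmax(z)_j = z_j - logsumexp(z)
        losses.append(-sum(tj * (zj - log_sum_exp) for zj, tj in zip(z, t)))
    return losses


def logreg_loss(x, y_true, weight, bias):
    """Mean softmax cross-entropy of logits = x @ weight + bias.

    weight is indexed weight[i][j]: input feature i, class j, i.e. shape
    (dims, num_classes), matching the (784, 10) variable in tf_logreg.
    Returns (mean_loss, logits).
    """
    logits = [
        [sum(xi * weight[i][j] for i, xi in enumerate(row)) + bias[j]
         for j in range(len(bias))]
        for row in x
    ]
    losses = softmax_xent(logits, y_true)
    return sum(losses) / len(losses), logits


if __name__ == "__main__":
    # Zero-initialized parameters, as in tf_logreg: 784 features -> 10 classes
    W = [[0.0] * 10 for _ in range(784)]
    b = [0.0] * 10
    x = [[1.0] * 784]
    y_true = [[1.0] + [0.0] * 9]
    loss, _ = logreg_loss(x, y_true, W, b)
    print(round(loss, 4))  # 2.3026, i.e. ln(10)
```

Because `tf_logreg` initializes `weight` and `bias` to zero, every logit starts at 0, the softmax is uniform over the 10 classes, and the initial mean loss is ln(10) ≈ 2.3026 for any input with one-hot labels. The sketch reproduces this, which makes it a quick sanity check for the first training step.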

Distributed deep learning system
