You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LogisticRegression.cs 2.1 kB

6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using TensorFlowNET.Examples.Utility;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// A logistic regression learning algorithm example using the TensorFlow library.
      /// The example is intended to use the MNIST database of handwritten digits.
      /// Port of:
      /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
      /// </summary>
      /// <remarks>
      /// Currently only the computation graph is built and the variables are initialized.
      /// Data loading and the training loop are not implemented yet.
      /// </remarks>
      public class LogisticRegression : Python, IExample
      {
          // Hyper-parameters. They are never reassigned, hence readonly.
          // training_epochs, batch_size and display_step are not read yet; they are
          // kept for the training loop that still has to be written.
          private readonly float learning_rate = 0.01f;
          private readonly int training_epochs = 25;
          private readonly int batch_size = 100;
          private readonly int display_step = 1;

          /// <summary>
          /// Entry point of the example: builds the model graph and initializes its variables.
          /// </summary>
          public void Run()
          {
              PrepareData();
          }

          /// <summary>
          /// Builds the logistic regression graph (input placeholders, weights, softmax model,
          /// cross-entropy cost and gradient-descent step) and runs the variable initializer
          /// inside a new session.
          /// </summary>
          private void PrepareData()
          {
              // TODO: load MNIST with one-hot labels before training, e.g.
              // MnistDataSet.read_data_sets("logistic_regression", one_hot: true)

              // tf Graph input: flattened 28*28 = 784-pixel images, and labels for the
              // 10 digit classes (0-9). -1 presumably leaves the batch dimension unspecified.
              var x = tf.placeholder(tf.float32, new TensorShape(-1, 784));
              var y = tf.placeholder(tf.float32, new TensorShape(-1, 10));

              // Model weights and bias, zero-initialized.
              var W = tf.Variable(tf.zeros(new Shape(784, 10)));
              var b = tf.Variable(tf.zeros(new Shape(10)));

              // Model: softmax(x * W + b).
              var pred = tf.nn.softmax(tf.matmul(x, W) + b);

              // Cost: mean over the batch of -sum(y * log(pred)) along axis 1 (cross entropy).
              // NOTE(review): log(softmax) yields -inf (and NaN through 0 * -inf) when a predicted
              // probability underflows to 0. A fused softmax-cross-entropy-with-logits op would be
              // numerically safer if this TensorFlow.NET version provides one.
              var log = tf.log(pred);
              var mul = y * log;
              var sum = tf.reduce_sum(mul, reduction_indices: 1);
              var neg = -sum;
              var cost = tf.reduce_mean(neg);

              // Gradient descent step that minimizes the cost. It is not run yet;
              // the training loop will feed batches into it.
              var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);

              // Op that assigns every variable its initial value.
              var init = tf.global_variables_initializer();

              // `with` comes from the Python base class; it presumably disposes the session
              // after the callback returns (TODO confirm).
              with(tf.Session(), sess =>
              {
                  // Run the initializer.
                  sess.run(init);
              });
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。