|
|
@@ -26,7 +26,7 @@ namespace TensorFlowNET.Examples |
|
|
|
|
|
|
|
private void PrepareData() |
|
|
|
{ |
|
|
|
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); |
|
|
|
//var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); |
|
|
|
|
|
|
|
// tf Graph Input |
|
|
|
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 |
|
|
@@ -40,8 +40,11 @@ namespace TensorFlowNET.Examples |
|
|
|
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax |
|
|
|
|
|
|
|
// Minimize error using cross entropy |
|
|
|
var sum = -tf.reduce_sum(y * tf.log(pred), reduction_indices: 1); |
|
|
|
var cost = tf.reduce_mean(sum); |
|
|
|
var log = tf.log(pred); |
|
|
|
var mul = y * log; |
|
|
|
var sum = tf.reduce_sum(mul, reduction_indices: 1); |
|
|
|
var neg = -sum; |
|
|
|
var cost = tf.reduce_mean(neg); |
|
|
|
|
|
|
|
// Gradient Descent |
|
|
|
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); |
|
|
|