
add predict for logistic regression.

haiping008 committed 476f8cfad0 · tags/v0.8.0 · 6 years ago
12 changed files with 196 additions and 18 deletions:

  1. docs/source/Graph.md (+58 -0)
  2. docs/source/LogisticRegression.md (+7 -0)
  3. docs/source/index.rst (+1 -0)
  4. src/TensorFlowNET.Core/Framework/meta_graph.py.cs (+1 -1)
  5. src/TensorFlowNET.Core/Graphs/FreezeGraph.cs (+23 -0)
  6. src/TensorFlowNET.Core/Graphs/Graph.Export.cs (+4 -1)
  7. src/TensorFlowNET.Core/Graphs/graph_io.py.cs (+1 -1)
  8. src/TensorFlowNET.Core/Sessions/_FetchHandler.cs (+17 -10)
  9. src/TensorFlowNET.Core/Train/Saving/Saver.cs (+1 -1)
  10. src/TensorFlowNET.Core/Train/tf.optimizers.cs (+2 -1)
  11. test/TensorFlowNET.Examples/LogisticRegression.cs (+49 -1)
  12. test/TensorFlowNET.Examples/python/logistic_regression.py (+32 -2)

docs/source/Graph.md (+58 -0)

@@ -21,3 +21,61 @@ A typical graph is looks like below:

![image](../assets/graph_vis_animation.gif)



### Save Model

Saving the model means saving all the values of the parameters and the graph.

```python
saver = tf.train.Saver()
saver.save(sess,'./tensorflowModel.ckpt')
```

After saving the model there will be four files:

* tensorflowModel.ckpt.meta: the MetaGraphDef, i.e. the graph structure and the metadata needed to re-import it.
* tensorflowModel.ckpt.data-00000-of-00001: the values of all variables.
* tensorflowModel.ckpt.index: an index that maps each variable name to its location in the data file.
* checkpoint: a plain-text file that records the most recent checkpoint paths.

We also create a protocol buffer file (.pbtxt). It is human-readable; pass `as_text: false` if you want the binary format instead.

* tensorflowModel.pbtxt:

This holds a network of nodes, each representing one operation, connected to each other as inputs and outputs.
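As a rough sketch (TensorFlow 1.x Python API; the tiny graph below is a made-up stand-in for the real model), both formats come from the same `tf.train.write_graph` call, switching only `as_text`:

```python
import tensorflow as tf

# A throwaway graph, just so there is something to serialize.
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
W = tf.Variable(tf.zeros([784, 10]), name='W')
pred = tf.nn.softmax(tf.matmul(x, W), name='pred')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Human-readable text format (.pbtxt).
    tf.train.write_graph(sess.graph_def, '.', 'tensorflowModel.pbtxt', as_text=True)
    # Binary format (.pb), i.e. as_text set to false.
    tf.train.write_graph(sess.graph_def, '.', 'tensorflowModel.pb', as_text=False)
```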



### Freezing the Graph

##### *Why do we need it?*

When we need to keep all the values of the variables and the graph structure in a single file, we have to freeze the graph.

```python
from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(input_graph = 'logistic_regression/tensorflowModel.pbtxt',
                          input_saver = "",
                          input_binary = False,
                          input_checkpoint = 'logistic_regression/tensorflowModel.ckpt',
                          output_node_names = "Softmax",
                          restore_op_name = "save/restore_all",
                          filename_tensor_name = "save/Const:0",
                          output_graph = 'frozentensorflowModel.pb',
                          clear_devices = True,
                          initializer_nodes = "")
```

### Optimizing for Inference

To reduce the amount of computation needed when the network is used only for inference, we can remove the parts of the graph that are only needed for training.
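One possible way to do this in Python is TensorFlow's `optimize_for_inference_lib`; a minimal sketch, assuming the frozen graph and the node names used in this chapter ('Placeholder' as input, 'Softmax' as output):

```python
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

# Load the frozen graph produced in the previous step.
graph_def = tf.GraphDef()
with tf.gfile.GFile('frozentensorflowModel.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Strip the parts of the graph that are only needed for training.
optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
    graph_def,
    input_node_names=['Placeholder'],
    output_node_names=['Softmax'],
    placeholder_type_enum=tf.float32.as_datatype_enum)

tf.train.write_graph(optimized_graph_def, '.', 'optimizedtensorflowModel.pb', as_text=False)
```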



### Restoring the Model
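A minimal sketch of loading the frozen graph back and running a prediction, again assuming the file and node names used above ('frozentensorflowModel.pb', 'Placeholder', 'Softmax'):

```python
import numpy as np
import tensorflow as tf

# Read the frozen GraphDef and import it into a fresh graph.
graph_def = tf.GraphDef()
with tf.gfile.GFile('frozentensorflowModel.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    x = graph.get_tensor_by_name('Placeholder:0')
    pred = graph.get_tensor_by_name('Softmax:0')
    # A single all-zero 28x28 image (flattened), just to exercise the call.
    result = sess.run(pred, feed_dict={x: np.zeros((1, 784), dtype=np.float32)})
    print(result.argmax())
```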




docs/source/LogisticRegression.md (+7 -0)

@@ -0,0 +1,7 @@
# Chapter. Logistic Regression

### What is logistic regression?
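In short, logistic regression is a linear model for classification: it applies a softmax (or, for two classes, a sigmoid) to a linear function of the input and is trained by minimizing cross-entropy. For the MNIST example in this chapter the model and loss are

$$\hat{y} = \mathrm{softmax}(W x + b), \qquad L = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{10} y_{ik}\,\log \hat{y}_{ik}$$

where $x$ is a flattened 28x28 image, $W$ and $b$ are the learned weights and bias, and $y$ is the one-hot label.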



The full example is [here](https://github.com/SciSharp/TensorFlow.NET/blob/master/test/TensorFlowNET.Examples/LogisticRegression.cs).

docs/source/index.rst (+1 -0)

@@ -26,4 +26,5 @@ Welcome to TensorFlow.NET's documentation!
Train
EagerMode
LinearRegression
LogisticRegression
ImageRecognition

src/TensorFlowNET.Core/Framework/meta_graph.py.cs (+1 -1)

@@ -184,7 +184,7 @@ namespace Tensorflow

// Adds graph_def or the default.
if (graph_def == null)
-meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true);
+meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true);
else
meta_graph_def.GraphDef = graph_def;



src/TensorFlowNET.Core/Graphs/FreezeGraph.cs (+23 -0)

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class FreezeGraph
{
public static void freeze_graph(string input_graph,
string input_saver,
bool input_binary,
string input_checkpoint,
string output_node_names,
string restore_op_name,
string filename_tensor_name,
string output_graph,
bool clear_devices,
string initializer_nodes)
{

}
}
}

src/TensorFlowNET.Core/Graphs/Graph.Export.cs (+4 -1)

@@ -18,7 +18,7 @@ namespace Tensorflow
return buffer;
}

-public GraphDef _as_graph_def(bool add_shapes = false)
+private GraphDef _as_graph_def(bool add_shapes = false)
{
var buffer = ToGraphDef(Status);
Status.Check();
@@ -30,5 +30,8 @@ namespace Tensorflow

return def;
}

public GraphDef as_graph_def(bool add_shapes = false)
=> _as_graph_def(add_shapes);
}
}

src/TensorFlowNET.Core/Graphs/graph_io.py.cs (+1 -1)

@@ -10,7 +10,7 @@ namespace Tensorflow
{
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
{
-var graph_def = graph._as_graph_def();
+var graph_def = graph.as_graph_def();
string path = Path.Combine(logdir, name);
if (as_text)
File.WriteAllText(path, graph_def.ToString());


src/TensorFlowNET.Core/Sessions/_FetchHandler.cs (+17 -10)

@@ -58,17 +58,24 @@ namespace Tensorflow
{
    var value = tensor_values[j];
    j += 1;
-   switch (value.dtype.Name)
+   if (value.ndim == 2)
    {
-       case "Int32":
-           full_values.Add(value.Data<int>(0));
-           break;
-       case "Single":
-           full_values.Add(value.Data<float>(0));
-           break;
-       case "Double":
-           full_values.Add(value.Data<double>(0));
-           break;
+       full_values.Add(value[0]);
    }
+   else
+   {
+       switch (value.dtype.Name)
+       {
+           case "Int32":
+               full_values.Add(value.Data<int>(0));
+               break;
+           case "Single":
+               full_values.Add(value.Data<float>(0));
+               break;
+           case "Double":
+               full_values.Add(value.Data<double>(0));
+               break;
+       }
+   }
}
i += 1;


src/TensorFlowNET.Core/Train/Saving/Saver.cs (+1 -1)

@@ -251,7 +251,7 @@ namespace Tensorflow
{
return export_meta_graph(
filename: filename,
-graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true),
+graph_def: ops.get_default_graph().as_graph_def(add_shapes: true),
saver_def: _saver_def,
collection_list: collection_list,
as_text: as_text,


src/TensorFlowNET.Core/Train/tf.optimizers.cs (+2 -1)

@@ -16,7 +16,8 @@ namespace Tensorflow

public static Saver Saver() => new Saver();

-public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);
+public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
+    => graph_io.write_graph(graph, logdir, name, as_text);

public static Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,


test/TensorFlowNET.Examples/LogisticRegression.cs (+49 -1)

@@ -2,6 +2,7 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
@@ -17,7 +18,7 @@ namespace TensorFlowNET.Examples
public class LogisticRegression : Python, IExample
{
private float learning_rate = 0.01f;
-private int training_epochs = 5;
+private int training_epochs = 10;
private int batch_size = 100;
private int display_step = 1;

@@ -78,6 +79,7 @@ namespace TensorFlowNET.Examples
}

print("Optimization Finished!");
// SaveModel(sess);

// Test model
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
@@ -85,6 +87,8 @@ namespace TensorFlowNET.Examples
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
print($"Accuracy: {acc.ToString("F4")}");

Predict();
});
}

@@ -92,5 +96,49 @@ namespace TensorFlowNET.Examples
{
mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
}

public void SaveModel(Session sess)
{
var saver = tf.train.Saver();
var save_path = saver.save(sess, "logistic_regression/model.ckpt");
tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true);

FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt",
input_saver: "",
input_binary: false,
input_checkpoint: "logistic_regression/model.ckpt",
output_node_names: "Softmax",
restore_op_name: "save/restore_all",
filename_tensor_name: "save/Const:0",
output_graph: "logistic_regression/model.pb",
clear_devices: true,
initializer_nodes: "");
}

public void Predict()
{
var graph = new Graph().as_default();
graph.Import(Path.Join("logistic_regression", "model.pb"));

with(tf.Session(graph), sess =>
{
// restoring the model
// var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta");
// saver.restore(sess, tf.train.latest_checkpoint('logistic_regression'));
var pred = graph.OperationByName("Softmax");
var output = pred.outputs[0];
var x = graph.OperationByName("Placeholder");
var input = x.outputs[0];

// predict
var (batch_xs, batch_ys) = mnist.train.next_batch(10);
var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)]));

if (results.argmax() == (batch_ys[0] as NDArray).argmax())
print("predicted OK!");
else
throw new ValueError("predict error, maybe 90% accuracy");
});
}
}
}

test/TensorFlowNET.Examples/python/logistic_regression.py (+32 -2)

@@ -16,7 +16,7 @@ mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Parameters
learning_rate = 0.01
-training_epochs = 25
+training_epochs = 10
batch_size = 100
display_step = 1

@@ -67,4 +67,34 @@ with tf.Session() as sess:
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-   print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
+   print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # predict
    # results = sess.run(pred, feed_dict={x: batch_xs[:1]})

    # save model
    saver = tf.train.Saver()
    save_path = saver.save(sess, "logistic_regression/model.ckpt")
    tf.train.write_graph(sess.graph.as_graph_def(), 'logistic_regression', 'model.pbtxt', as_text=True)

    freeze_graph.freeze_graph(input_graph = 'logistic_regression/model.pbtxt',
                              input_saver = "",
                              input_binary = False,
                              input_checkpoint = 'logistic_regression/model.ckpt',
                              output_node_names = "Softmax",
                              restore_op_name = "save/restore_all",
                              filename_tensor_name = "save/Const:0",
                              output_graph = 'logistic_regression/model.pb',
                              clear_devices = True,
                              initializer_nodes = "")

    # restoring the model
    saver = tf.train.import_meta_graph('logistic_regression/model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('logistic_regression'))

    # predict
    # pred = graph._nodes_by_name["Softmax"]
    # output = pred.outputs[0]
    # x = graph._nodes_by_name["Placeholder"]
    # input = x.outputs[0]
    # results = sess.run(output, feed_dict={input: batch_xs[:1]})
