@@ -21,3 +21,61 @@ A typical graph is looks like below: | |||
 | |||
### Save Model | |||
Saving the model means saving all the values of the parameters and the graph. | |||
```python | |||
saver = tf.train.Saver() | |||
saver.save(sess,'./tensorflowModel.ckpt') | |||
``` | |||
After saving the model there will be four files: | |||
* tensorflowModel.ckpt.meta: | |||
* tensorflowModel.ckpt.data-00000-of-00001: | |||
* tensorflowModel.ckpt.index | |||
* checkpoint | |||
We also created a protocol buffer file .pbtxt. It is human readable if you want to convert it to binary: `as_text: false`. | |||
* tensorflowModel.pbtxt: | |||
This holds a network of nodes, each representing one operation, connected to each other as inputs and outputs. | |||
### Freezing the Graph | |||
##### *Why we need it?* | |||
When we need to keep all the values of the variables and the Graph structure in a single file we have to freeze the graph. | |||
```csharp | |||
from tensorflow.python.tools import freeze_graph | |||
freeze_graph.freeze_graph(input_graph = 'logistic_regression/tensorflowModel.pbtxt', | |||
input_saver = "", | |||
input_binary = False, | |||
input_checkpoint = 'logistic_regression/tensorflowModel.ckpt', | |||
output_node_names = "Softmax", | |||
restore_op_name = "save/restore_all", | |||
filename_tensor_name = "save/Const:0", | |||
output_graph = 'frozentensorflowModel.pb', | |||
clear_devices = True, | |||
initializer_nodes = "") | |||
``` | |||
### Optimizing for Inference | |||
To Reduce the amount of computation needed when the network is used only for inferences we can remove some parts of a graph that are only needed for training. | |||
### Restoring the Model | |||
@@ -0,0 +1,7 @@ | |||
# Chapter. Logistic Regression | |||
### What is logistic regression? | |||
The full example is [here](https://github.com/SciSharp/TensorFlow.NET/blob/master/test/TensorFlowNET.Examples/LogisticRegression.cs). |
@@ -26,4 +26,5 @@ Welcome to TensorFlow.NET's documentation! | |||
Train | |||
EagerMode | |||
LinearRegression | |||
LogisticRegression | |||
ImageRecognition |
@@ -184,7 +184,7 @@ namespace Tensorflow | |||
// Adds graph_def or the default. | |||
if (graph_def == null) | |||
meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true); | |||
meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true); | |||
else | |||
meta_graph_def.GraphDef = graph_def; | |||
@@ -0,0 +1,23 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class FreezeGraph | |||
{ | |||
public static void freeze_graph(string input_graph, | |||
string input_saver, | |||
bool input_binary, | |||
string input_checkpoint, | |||
string output_node_names, | |||
string restore_op_name, | |||
string filename_tensor_name, | |||
string output_graph, | |||
bool clear_devices, | |||
string initializer_nodes) | |||
{ | |||
} | |||
} | |||
} |
@@ -18,7 +18,7 @@ namespace Tensorflow | |||
return buffer; | |||
} | |||
public GraphDef _as_graph_def(bool add_shapes = false) | |||
private GraphDef _as_graph_def(bool add_shapes = false) | |||
{ | |||
var buffer = ToGraphDef(Status); | |||
Status.Check(); | |||
@@ -30,5 +30,8 @@ namespace Tensorflow | |||
return def; | |||
} | |||
public GraphDef as_graph_def(bool add_shapes = false) | |||
=> _as_graph_def(add_shapes); | |||
} | |||
} |
@@ -10,7 +10,7 @@ namespace Tensorflow | |||
{ | |||
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) | |||
{ | |||
var graph_def = graph._as_graph_def(); | |||
var graph_def = graph.as_graph_def(); | |||
string path = Path.Combine(logdir, name); | |||
if (as_text) | |||
File.WriteAllText(path, graph_def.ToString()); | |||
@@ -58,17 +58,24 @@ namespace Tensorflow | |||
{ | |||
var value = tensor_values[j]; | |||
j += 1; | |||
switch (value.dtype.Name) | |||
if (value.ndim == 2) | |||
{ | |||
case "Int32": | |||
full_values.Add(value.Data<int>(0)); | |||
break; | |||
case "Single": | |||
full_values.Add(value.Data<float>(0)); | |||
break; | |||
case "Double": | |||
full_values.Add(value.Data<double>(0)); | |||
break; | |||
full_values.Add(value[0]); | |||
} | |||
else | |||
{ | |||
switch (value.dtype.Name) | |||
{ | |||
case "Int32": | |||
full_values.Add(value.Data<int>(0)); | |||
break; | |||
case "Single": | |||
full_values.Add(value.Data<float>(0)); | |||
break; | |||
case "Double": | |||
full_values.Add(value.Data<double>(0)); | |||
break; | |||
} | |||
} | |||
} | |||
i += 1; | |||
@@ -251,7 +251,7 @@ namespace Tensorflow | |||
{ | |||
return export_meta_graph( | |||
filename: filename, | |||
graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true), | |||
graph_def: ops.get_default_graph().as_graph_def(add_shapes: true), | |||
saver_def: _saver_def, | |||
collection_list: collection_list, | |||
as_text: as_text, | |||
@@ -16,7 +16,8 @@ namespace Tensorflow | |||
public static Saver Saver() => new Saver(); | |||
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); | |||
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) | |||
=> graph_io.write_graph(graph, logdir, name, as_text); | |||
public static Saver import_meta_graph(string meta_graph_or_file, | |||
bool clear_devices = false, | |||
@@ -2,6 +2,7 @@ | |||
using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow; | |||
@@ -17,7 +18,7 @@ namespace TensorFlowNET.Examples | |||
public class LogisticRegression : Python, IExample | |||
{ | |||
private float learning_rate = 0.01f; | |||
private int training_epochs = 5; | |||
private int training_epochs = 10; | |||
private int batch_size = 100; | |||
private int display_step = 1; | |||
@@ -78,6 +79,7 @@ namespace TensorFlowNET.Examples | |||
} | |||
print("Optimization Finished!"); | |||
// SaveModel(sess); | |||
// Test model | |||
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | |||
@@ -85,6 +87,8 @@ namespace TensorFlowNET.Examples | |||
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | |||
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); | |||
print($"Accuracy: {acc.ToString("F4")}"); | |||
Predict(); | |||
}); | |||
} | |||
@@ -92,5 +96,49 @@ namespace TensorFlowNET.Examples | |||
{ | |||
mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); | |||
} | |||
public void SaveModel(Session sess) | |||
{ | |||
var saver = tf.train.Saver(); | |||
var save_path = saver.save(sess, "logistic_regression/model.ckpt"); | |||
tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true); | |||
FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt", | |||
input_saver: "", | |||
input_binary: false, | |||
input_checkpoint: "logistic_regression/model.ckpt", | |||
output_node_names: "Softmax", | |||
restore_op_name: "save/restore_all", | |||
filename_tensor_name: "save/Const:0", | |||
output_graph: "logistic_regression/model.pb", | |||
clear_devices: true, | |||
initializer_nodes: ""); | |||
} | |||
public void Predict() | |||
{ | |||
var graph = new Graph().as_default(); | |||
graph.Import(Path.Join("logistic_regression", "model.pb")); | |||
with(tf.Session(graph), sess => | |||
{ | |||
// restoring the model | |||
// var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta"); | |||
// saver.restore(sess, tf.train.latest_checkpoint('logistic_regression')); | |||
var pred = graph.OperationByName("Softmax"); | |||
var output = pred.outputs[0]; | |||
var x = graph.OperationByName("Placeholder"); | |||
var input = x.outputs[0]; | |||
// predict | |||
var (batch_xs, batch_ys) = mnist.train.next_batch(10); | |||
var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)])); | |||
if (results.argmax() == (batch_ys[0] as NDArray).argmax()) | |||
print("predicted OK!"); | |||
else | |||
throw new ValueError("predict error, maybe 90% accuracy"); | |||
}); | |||
} | |||
} | |||
} |
@@ -16,7 +16,7 @@ mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) | |||
# Parameters | |||
learning_rate = 0.01 | |||
training_epochs = 25 | |||
training_epochs = 10 | |||
batch_size = 100 | |||
display_step = 1 | |||
@@ -67,4 +67,34 @@ with tf.Session() as sess: | |||
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) | |||
# Calculate accuracy | |||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | |||
print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) | |||
print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) | |||
# predict | |||
# results = sess.run(pred, feed_dict={x: batch_xs[:1]}) | |||
# save model | |||
saver = tf.train.Saver() | |||
save_path = saver.save(sess, "logistic_regression/model.ckpt") | |||
tf.train.write_graph(sess.graph.as_graph_def(),'logistic_regression','model.pbtxt', as_text=True) | |||
freeze_graph.freeze_graph(input_graph = 'logistic_regression/model.pbtxt', | |||
input_saver = "", | |||
input_binary = False, | |||
input_checkpoint = 'logistic_regression/model.ckpt', | |||
output_node_names = "Softmax", | |||
restore_op_name = "save/restore_all", | |||
filename_tensor_name = "save/Const:0", | |||
output_graph = 'logistic_regression/model.pb', | |||
clear_devices = True, | |||
initializer_nodes = "") | |||
# restoring the model | |||
saver = tf.train.import_meta_graph('logistic_regression/tensorflowModel.ckpt.meta') | |||
saver.restore(sess,tf.train.latest_checkpoint('logistic_regression')) | |||
# predict | |||
# pred = graph._nodes_by_name["Softmax"] | |||
# output = pred.outputs[0] | |||
# x = graph._nodes_by_name["Placeholder"] | |||
# input = x.outputs[0] | |||
# results = sess.run(output, feed_dict={input: batch_xs[:1]}) |