diff --git a/docs/source/Graph.md b/docs/source/Graph.md index f6edbfc6..7bc473f2 100644 --- a/docs/source/Graph.md +++ b/docs/source/Graph.md @@ -21,3 +21,61 @@ A typical graph is looks like below: ![image](../assets/graph_vis_animation.gif) + + +### Save Model + +Saving the model means saving all the values of the parameters and the graph. + +```python +saver = tf.train.Saver() +saver.save(sess,'./tensorflowModel.ckpt') +``` + +After saving the model there will be four files: + +* tensorflowModel.ckpt.meta: +* tensorflowModel.ckpt.data-00000-of-00001: +* tensorflowModel.ckpt.index +* checkpoint + +We also created a protocol buffer file .pbtxt. It is human readable; if you want a binary file instead, pass `as_text: false`. + +* tensorflowModel.pbtxt: + +This holds a network of nodes, each representing one operation, connected to each other as inputs and outputs. + + + +### Freezing the Graph + +##### *Why do we need it?* + +When we need to keep all the values of the variables and the Graph structure in a single file, we have to freeze the graph. + +```python +from tensorflow.python.tools import freeze_graph + +freeze_graph.freeze_graph(input_graph = 'logistic_regression/tensorflowModel.pbtxt', + input_saver = "", + input_binary = False, + input_checkpoint = 'logistic_regression/tensorflowModel.ckpt', + output_node_names = "Softmax", + restore_op_name = "save/restore_all", + filename_tensor_name = "save/Const:0", + output_graph = 'frozentensorflowModel.pb', + clear_devices = True, + initializer_nodes = "") + +``` + +### Optimizing for Inference + +To reduce the amount of computation needed when the network is used only for inference, we can remove the parts of the graph that are only needed for training. + + + +### Restoring the Model + + + diff --git a/docs/source/LogisticRegression.md b/docs/source/LogisticRegression.md new file mode 100644 index 00000000..00aa2f05 --- /dev/null +++ b/docs/source/LogisticRegression.md @@ -0,0 +1,7 @@ +# Chapter. 
Logistic Regression + +### What is logistic regression? + + + +The full example is [here](https://github.com/SciSharp/TensorFlow.NET/blob/master/test/TensorFlowNET.Examples/LogisticRegression.cs). \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index b41c621f..51ed7727 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,4 +26,5 @@ Welcome to TensorFlow.NET's documentation! Train EagerMode LinearRegression + LogisticRegression ImageRecognition \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index 4e43fd7e..24be2cd6 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -184,7 +184,7 @@ namespace Tensorflow // Adds graph_def or the default. if (graph_def == null) - meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true); + meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true); else meta_graph_def.GraphDef = graph_def; diff --git a/src/TensorFlowNET.Core/Graphs/FreezeGraph.cs b/src/TensorFlowNET.Core/Graphs/FreezeGraph.cs new file mode 100644 index 00000000..a77b4f3a --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/FreezeGraph.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class FreezeGraph + { + public static void freeze_graph(string input_graph, + string input_saver, + bool input_binary, + string input_checkpoint, + string output_node_names, + string restore_op_name, + string filename_tensor_name, + string output_graph, + bool clear_devices, + string initializer_nodes) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 6501de70..0ca80be3 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -18,7 +18,7 @@ namespace 
Tensorflow return buffer; } - public GraphDef _as_graph_def(bool add_shapes = false) + private GraphDef _as_graph_def(bool add_shapes = false) { var buffer = ToGraphDef(Status); Status.Check(); @@ -30,5 +30,8 @@ namespace Tensorflow return def; } + + public GraphDef as_graph_def(bool add_shapes = false) + => _as_graph_def(add_shapes); } } diff --git a/src/TensorFlowNET.Core/Graphs/graph_io.py.cs b/src/TensorFlowNET.Core/Graphs/graph_io.py.cs index 31f33221..7abc4cab 100644 --- a/src/TensorFlowNET.Core/Graphs/graph_io.py.cs +++ b/src/TensorFlowNET.Core/Graphs/graph_io.py.cs @@ -10,7 +10,7 @@ namespace Tensorflow { public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) { - var graph_def = graph._as_graph_def(); + var graph_def = graph.as_graph_def(); string path = Path.Combine(logdir, name); if (as_text) File.WriteAllText(path, graph_def.ToString()); diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 01231597..626fc6e8 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -58,17 +58,24 @@ namespace Tensorflow { var value = tensor_values[j]; j += 1; - switch (value.dtype.Name) + if (value.ndim == 2) { - case "Int32": - full_values.Add(value.Data(0)); - break; - case "Single": - full_values.Add(value.Data(0)); - break; - case "Double": - full_values.Add(value.Data(0)); - break; + full_values.Add(value[0]); + } + else + { + switch (value.dtype.Name) + { + case "Int32": + full_values.Add(value.Data(0)); + break; + case "Single": + full_values.Add(value.Data(0)); + break; + case "Double": + full_values.Add(value.Data(0)); + break; + } } } i += 1; diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index d81908e7..c223ca97 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -251,7 +251,7 @@ 
namespace Tensorflow { return export_meta_graph( filename: filename, - graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true), + graph_def: ops.get_default_graph().as_graph_def(add_shapes: true), saver_def: _saver_def, collection_list: collection_list, as_text: as_text, diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index a7a3a39b..b4925f3a 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -16,7 +16,8 @@ namespace Tensorflow public static Saver Saver() => new Saver(); - public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); + public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) + => graph_io.write_graph(graph, logdir, name, as_text); public static Saver import_meta_graph(string meta_graph_or_file, bool clear_devices = false, diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index e44002cc..7e2f8e3b 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -2,6 +2,7 @@ using NumSharp.Core; using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Text; using Tensorflow; @@ -17,7 +18,7 @@ namespace TensorFlowNET.Examples public class LogisticRegression : Python, IExample { private float learning_rate = 0.01f; - private int training_epochs = 5; + private int training_epochs = 10; private int batch_size = 100; private int display_step = 1; @@ -78,6 +79,7 @@ namespace TensorFlowNET.Examples } print("Optimization Finished!"); + // SaveModel(sess); // Test model var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); @@ -85,6 +87,8 @@ namespace TensorFlowNET.Examples var accuracy = 
tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); print($"Accuracy: {acc.ToString("F4")}"); + + Predict(); }); } @@ -92,5 +96,49 @@ namespace TensorFlowNET.Examples { mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); } + + public void SaveModel(Session sess) + { + var saver = tf.train.Saver(); + var save_path = saver.save(sess, "logistic_regression/model.ckpt"); + tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true); + + FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt", + input_saver: "", + input_binary: false, + input_checkpoint: "logistic_regression/model.ckpt", + output_node_names: "Softmax", + restore_op_name: "save/restore_all", + filename_tensor_name: "save/Const:0", + output_graph: "logistic_regression/model.pb", + clear_devices: true, + initializer_nodes: ""); + } + + public void Predict() + { + var graph = new Graph().as_default(); + graph.Import(Path.Join("logistic_regression", "model.pb")); + + with(tf.Session(graph), sess => + { + // restoring the model + // var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta"); + // saver.restore(sess, tf.train.latest_checkpoint('logistic_regression')); + var pred = graph.OperationByName("Softmax"); + var output = pred.outputs[0]; + var x = graph.OperationByName("Placeholder"); + var input = x.outputs[0]; + + // predict + var (batch_xs, batch_ys) = mnist.train.next_batch(10); + var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)])); + + if (results.argmax() == (batch_ys[0] as NDArray).argmax()) + print("predicted OK!"); + else + throw new ValueError("predict error, maybe 90% accuracy"); + }); + } } } diff --git a/test/TensorFlowNET.Examples/python/logistic_regression.py b/test/TensorFlowNET.Examples/python/logistic_regression.py index 338ebe5a..236d83d1 100644 --- 
a/test/TensorFlowNET.Examples/python/logistic_regression.py +++ b/test/TensorFlowNET.Examples/python/logistic_regression.py @@ -16,7 +16,7 @@ mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) # Parameters learning_rate = 0.01 -training_epochs = 25 +training_epochs = 10 batch_size = 100 display_step = 1 @@ -67,4 +67,34 @@ with tf.Session() as sess: correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) # Calculate accuracy accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) \ No newline at end of file + print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) + + # predict + # results = sess.run(pred, feed_dict={x: batch_xs[:1]}) + + # save model + saver = tf.train.Saver() + save_path = saver.save(sess, "logistic_regression/model.ckpt") + tf.train.write_graph(sess.graph.as_graph_def(),'logistic_regression','model.pbtxt', as_text=True) + + freeze_graph.freeze_graph(input_graph = 'logistic_regression/model.pbtxt', + input_saver = "", + input_binary = False, + input_checkpoint = 'logistic_regression/model.ckpt', + output_node_names = "Softmax", + restore_op_name = "save/restore_all", + filename_tensor_name = "save/Const:0", + output_graph = 'logistic_regression/model.pb', + clear_devices = True, + initializer_nodes = "") + + # restoring the model (the .meta file is named after the checkpoint prefix passed to saver.save above) + saver = tf.train.import_meta_graph('logistic_regression/model.ckpt.meta') + saver.restore(sess,tf.train.latest_checkpoint('logistic_regression')) + + # predict + # pred = graph._nodes_by_name["Softmax"] + # output = pred.outputs[0] + # x = graph._nodes_by_name["Placeholder"] + # input = x.outputs[0] + # results = sess.run(output, feed_dict={input: batch_xs[:1]}) \ No newline at end of file