diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 707195b3..5809d78f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -17,5 +17,18 @@ namespace Tensorflow return buffer; } + + public GraphDef _as_graph_def() + { + var buffer = ToGraphDef(Status); + Status.Check(); + var def = GraphDef.Parser.ParseFrom(buffer); + buffer.Dispose(); + + // Strip the experimental library field iff it's empty. + // if(def.Library.Function.Count == 0) + + return def; + } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 42cf5111..a926a57f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -53,7 +53,7 @@ namespace Tensorflow return null; } - + private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) { string types_str = ""; diff --git a/src/TensorFlowNET.Core/Graphs/graph_io.py.cs b/src/TensorFlowNET.Core/Graphs/graph_io.py.cs new file mode 100644 index 00000000..5021bb28 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/graph_io.py.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Tensorflow +{ + public class graph_io + { + public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) + { + var def = graph._as_graph_def(); + string path = Path.Combine(logdir, name); + string text = def.ToString(); + if (as_text) + File.WriteAllText(path, text); + + return path; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index ba4fbea8..7d5f1527 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.IO; using System.Text; namespace Tensorflow @@ -11,6 +12,8 @@ namespace Tensorflow public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); public static Saver Saver() => new Saver(); + + public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); } } } diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 541bf893..b7c33e5b 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -9,6 +9,14 @@ namespace TensorFlowNET.UnitTest [TestClass] public class TrainSaverTest : Python { + [TestMethod] + public void WriteGraph() + { + var v = tf.Variable(0, name: "my_variable"); + var sess = tf.Session(); + tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); + } + [TestMethod] public void Save() {