@@ -17,5 +17,18 @@ namespace Tensorflow | |||||
return buffer; | return buffer; | ||||
} | } | ||||
public GraphDef _as_graph_def() | |||||
{ | |||||
var buffer = ToGraphDef(Status); | |||||
Status.Check(); | |||||
var def = GraphDef.Parser.ParseFrom(buffer); | |||||
buffer.Dispose(); | |||||
// Strip the experimental library field iff it's empty. | |||||
// if(def.Library.Function.Count == 0) | |||||
return def; | |||||
} | |||||
} | } | ||||
} | } |
@@ -53,7 +53,7 @@ namespace Tensorflow | |||||
return null; | return null; | ||||
} | } | ||||
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | ||||
{ | { | ||||
string types_str = ""; | string types_str = ""; | ||||
@@ -0,0 +1,21 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class graph_io | |||||
{ | |||||
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) | |||||
{ | |||||
var def = graph._as_graph_def(); | |||||
string path = Path.Combine(logdir, name); | |||||
string text = def.ToString(); | |||||
if (as_text) | |||||
File.WriteAllText(path, text); | |||||
return path; | |||||
} | |||||
} | |||||
} |
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -11,6 +12,8 @@ namespace Tensorflow | |||||
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | ||||
public static Saver Saver() => new Saver(); | public static Saver Saver() => new Saver(); | ||||
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -9,6 +9,14 @@ namespace TensorFlowNET.UnitTest | |||||
[TestClass] | [TestClass] | ||||
public class TrainSaverTest : Python | public class TrainSaverTest : Python | ||||
{ | { | ||||
[TestMethod] | |||||
public void WriteGraph() | |||||
{ | |||||
var v = tf.Variable(0, name: "my_variable"); | |||||
var sess = tf.Session(); | |||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Save() | public void Save() | ||||
{ | { | ||||