Browse Source

graph_io.write_graph

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
fe06a29c97
5 changed files with 46 additions and 1 deletions
  1. +13
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +21
    -0
      src/TensorFlowNET.Core/Graphs/graph_io.py.cs
  4. +3
    -0
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  5. +8
    -0
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 13
- 0
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -17,5 +17,18 @@ namespace Tensorflow

return buffer;
}

public GraphDef _as_graph_def()
{
var buffer = ToGraphDef(Status);
Status.Check();
var def = GraphDef.Parser.ParseFrom(buffer);
buffer.Dispose();

// Strip the experimental library field iff it's empty.
// if(def.Library.Function.Count == 0)

return def;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow

return null;
}
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
{
string types_str = "";


+ 21
- 0
src/TensorFlowNET.Core/Graphs/graph_io.py.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow
{
public class graph_io
{
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
{
var def = graph._as_graph_def();
string path = Path.Combine(logdir, name);
string text = def.ToString();
if (as_text)
File.WriteAllText(path, text);

return path;
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow
@@ -11,6 +12,8 @@ namespace Tensorflow
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);

public static Saver Saver() => new Saver();

public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);
}
}
}

+ 8
- 0
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -9,6 +9,14 @@ namespace TensorFlowNET.UnitTest
[TestClass]
public class TrainSaverTest : Python
{
[TestMethod]
public void WriteGraph()
{
var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
}

[TestMethod]
public void Save()
{


Loading…
Cancel
Save