Browse Source

import_graph_def

tags/v0.8.0
haiping008 6 years ago
parent
commit
fdf8231a72
12 changed files with 300 additions and 23 deletions
  1. +22
    -0
      src/TensorFlowNET.Core/Framework/c_api_util.py.cs
  2. +158
    -0
      src/TensorFlowNET.Core/Framework/importer.py.cs
  3. +54
    -0
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  4. +12
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  5. +7
    -4
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  6. +7
    -0
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  7. +27
    -0
      src/TensorFlowNET.Core/Train/Saving/saver.py.cs
  8. +6
    -0
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  9. +0
    -14
      src/TensorFlowNET.Core/c_api_util.cs
  10. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  11. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  12. +5
    -3
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 22
- 0
src/TensorFlowNET.Core/Framework/c_api_util.py.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class c_api_util
{
public static TF_Output tf_output(IntPtr c_op, int index) => new TF_Output(c_op, index);

public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions();

public static IntPtr tf_buffer(byte[] data)
{
if (data != null)
throw new NotImplementedException("");
// var buf = c_api.TF_NewBufferFromString(data);
else
throw new NotImplementedException("");
}
}
}

+ 158
- 0
src/TensorFlowNET.Core/Framework/importer.py.cs View File

@@ -0,0 +1,158 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.OpDef.Types;

namespace Tensorflow
{
public class importer
{
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
string name = "",
OpList producer_op_list = null)
{
var op_dict = op_def_registry.get_registered_ops();

graph_def = _ProcessGraphDefParam(graph_def, op_dict);
input_map = _ProcessInputMapParam(input_map);
return_elements = _ProcessReturnElementsParam(return_elements);

if (producer_op_list != null)
_RemoveDefaultAttrs(op_dict, producer_op_list, graph_def);

string prefix = "";
var graph = ops.get_default_graph();
Python.with<ops.name_scope>(new ops.name_scope(name, "import", input_map.Values), scope =>
{
/*prefix = scope;
if (!string.IsNullOrEmpty(prefix))
prefix = prefix.Substring(0, prefix.Length - 1);
else
prefix = "";*/

// Generate any input map tensors inside name scope
input_map = _ConvertInputMapValues(name, input_map);
});

var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);

var bytes = graph_def.ToByteString().ToArray();

var status = new Status();
c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status);

throw new NotImplementedException("importer.import_graph_def");
}

public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
string prefix,
Dictionary<string, Tensor> input_map,
string[] return_elements)
{
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);

foreach(var input in input_map)
{
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
}

if (return_elements == null)
return_elements = new string[0];

foreach (var name in return_elements)
{
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
}
}

public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map)
{
return input_map;
}

public static GraphDef _ProcessGraphDefParam(GraphDef graph_def, Dictionary<string, OpDef> op_dict)
{
foreach(var node in graph_def.Node)
{
if (!op_dict.ContainsKey(node.Op))
continue;

var op_def = op_dict[node.Op];
_SetDefaultAttrValues(node, op_def);
}

return graph_def;
}

private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
{
foreach(var attr_def in op_def.Attr)
{
var key = attr_def.Name;
if(attr_def.DefaultValue != null)
{
var value = node_def.Attr[key];
if (value == null)
node_def.Attr[key] = attr_def.DefaultValue;
}
}
}

private static Dictionary<string, Tensor> _ProcessInputMapParam(Dictionary<string, Tensor> input_map)
{
if (input_map == null)
return new Dictionary<string, Tensor>();

return input_map;
}

private static string[] _ProcessReturnElementsParam(string[] return_elements)
{
if (return_elements == null)
return null;

return return_elements;
}

private static void _RemoveDefaultAttrs(Dictionary<string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def)
{
var producer_op_dict = new Dictionary<string, OpDef>();
producer_op_list.Op.Select(op =>
{
producer_op_dict[op.Name] = op;
return op;
}).ToArray();

foreach(var node in graph_def.Node)
{
// Remove any default attr values that aren't in op_def.
if (producer_op_dict.ContainsKey(node.Op))
{
var op_def = op_dict[node.Op];
var producer_op_def = producer_op_dict[node.Op];
foreach(var key in node.Attr)
{
if(_FindAttrInOpDef(key.Key, op_def) == null)
{
var attr_def = _FindAttrInOpDef(key.Key, producer_op_def);
if (attr_def != null && attr_def.DefaultValue != null &&
node.Attr[key.Key] == attr_def.DefaultValue)
node.Attr[key.Key].ClearValue();
}
}
}
}
}

private static AttrDef _FindAttrInOpDef(string name, OpDef op_def)
{
return op_def.Attr.FirstOrDefault(x => x.Name == name);
}
}
}

+ 54
- 0
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using static Tensorflow.MetaGraphDef.Types;
@@ -8,6 +9,59 @@ namespace Tensorflow
{
public class meta_graph
{
public static MetaGraphDef read_meta_graph_file(string filename)
{
var bytes = File.ReadAllBytes(filename);
var meta_graph_def = MetaGraphDef.Parser.ParseFrom(bytes);
return meta_graph_def;
}

public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
Dictionary<string, Tensor> input_map = null,
string unbound_inputs_col_name = "unbound_inputs",
string[] return_elements = null)
{
var meta_graph_def = meta_graph_or_file;

if (!string.IsNullOrEmpty(unbound_inputs_col_name))
{
foreach(var col in meta_graph_def.CollectionDef)
{
if(col.Key == unbound_inputs_col_name)
{
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
}
}

// Sets graph to default graph if it's not passed in.
var graph = ops.get_default_graph();

// Gathers the list of nodes we are interested in.
OpList producer_op_list = null;
if (meta_graph_def.MetaInfoDef.StrippedOpList != null)
producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList;
var input_graph_def = meta_graph_def.GraphDef;
// Remove all the explicit device specifications for this node. This helps to
// make the graph more portable.
if (clear_devices)
foreach (var node in input_graph_def.Node)
node.Device = "";

var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
importer.import_graph_def(input_graph_def,
name: scope_to_prepend_to_names,
input_map: input_map,
producer_op_list: producer_op_list,
return_elements: return_elements);

// Restores all the other collections.
var variable_objects = new Dictionary<string, RefVariable>();

}

/// <summary>
/// Returns `MetaGraphDef` proto. Optionally writes it to filename.
/// </summary>


+ 12
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -218,6 +218,18 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix);

/// <summary>
/// Set whether to uniquify imported operation names. If true, imported operation
/// names will be modified if their name already exists in the graph. If false,
/// conflicting names will be treated as an error. Note that this option has no
/// effect if a prefix is set, since the prefix will guarantee all names are
/// unique. Defaults to false.
/// </summary>
/// <param name="ops">TF_ImportGraphDefOptions*</param>
/// <param name="uniquify_prefix">unsigned char</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix);

/// <summary>
/// Fetches the return operations requested via
/// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched


+ 7
- 4
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -4,7 +4,7 @@
<TargetFramework>netstandard2.0</TargetFramework>
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<Version>0.1.0</Version>
<Version>0.2.0</Version>
<Authors>Haiping Chen</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,10 +16,13 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags>
<Description>Google's TensorFlow binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.1.0.0</AssemblyVersion>
<PackageReleaseNotes>Implemented the tf.Variable().
TensorFlow 1.13 RC.</PackageReleaseNotes>
<AssemblyVersion>0.2.0.0</AssemblyVersion>
<PackageReleaseNotes>Added a bunch of APIs.
Fixed String tensor creation bug.
Upgraded to TensorFlow 1.13 RC-1.
</PackageReleaseNotes>
<LangVersion>7.2</LangVersion>
<FileVersion>0.2.0.0</FileVersion>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


+ 7
- 0
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -193,6 +193,13 @@ namespace Tensorflow
return _is_empty ? string.Empty : model_checkpoint_path;
}

public Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "")
{
return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope);
}

/// <summary>
/// Writes `MetaGraphDef` to save_path/filename.
/// </summary>


+ 27
- 0
src/TensorFlowNET.Core/Train/Saving/saver.py.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class saver
{
public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
string[] return_elements = null)
{
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);

meta_graph.import_scoped_meta_graph_with_return_elements(
meta_graph_def,
clear_devices: clear_devices,
import_scope: import_scope,
return_elements: return_elements);

return null;
/*var (imported_vars, imported_return_elements) = (
, false);*/
}
}
}

+ 6
- 0
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -14,6 +14,12 @@ namespace Tensorflow
public static Saver Saver() => new Saver();

public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);

public static Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file,
clear_devices,
import_scope);
}
}
}

+ 0
- 14
src/TensorFlowNET.Core/c_api_util.cs View File

@@ -1,14 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class c_api_util
{
public static TF_Output tf_output(IntPtr c_op, int index)
{
return new TF_Output(c_op, index);
}
}
}

+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -7,7 +7,7 @@

<ItemGroup>
<PackageReference Include="NumSharp" Version="0.7.1" />
<PackageReference Include="TensorFlow.NET" Version="0.1.0" />
<PackageReference Include="TensorFlow.NET" Version="0.2.0" />
</ItemGroup>

<ItemGroup>


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -20,7 +20,7 @@
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
<PackageReference Include="TensorFlow.NET" Version="0.1.0" />
<PackageReference Include="TensorFlow.NET" Version="0.2.0" />
</ItemGroup>

<ItemGroup>


+ 5
- 3
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -20,9 +20,10 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void ImportGraph()
{
var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt");
with<Session>(tf.Session(), sess =>
{
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");
});
}

[TestMethod]
@@ -45,6 +46,7 @@ namespace TensorFlowNET.UnitTest
});
}

[TestMethod]
public void Save2()
{
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);


Loading…
Cancel
Save