Browse Source

Change Graph.Import interface.

tags/v0.12
Oceania2018 6 years ago
parent
commit
f32692b1c7
7 changed files with 44 additions and 46 deletions
  1. +10
    -9
      src/TensorFlowNET.Core/Framework/importer.py.cs
  2. +22
    -24
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  3. +3
    -7
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  4. +2
    -2
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  5. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs
  6. +2
    -1
      test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs
  7. +4
    -2
      test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs

+ 10
- 9
src/TensorFlowNET.Core/Framework/importer.py.cs View File

@@ -54,16 +54,17 @@ namespace Tensorflow
input_map = _ConvertInputMapValues(name, input_map); input_map = _ConvertInputMapValues(name, input_map);
}); });


var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);

var bytes = graph_def.ToByteString().ToArray(); var bytes = graph_def.ToByteString().ToArray();
IntPtr buffer = c_api_util.tf_buffer(bytes);

var status = new Status();
// need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true);
using (var buffer = c_api_util.tf_buffer(bytes))
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
using (var status = new Status())
{
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true);
c_api.TF_DeleteImportGraphDefResults(results);
}


_ProcessNewOps(graph); _ProcessNewOps(graph);




+ 22
- 24
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -23,6 +23,7 @@ namespace Tensorflow
{ {
public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s)
{ {
as_default();
var num_return_outputs = opts.NumReturnOutputs; var num_return_outputs = opts.NumReturnOutputs;
var return_outputs = new TF_Output[num_return_outputs]; var return_outputs = new TF_Output[num_return_outputs];
int size = Marshal.SizeOf<TF_Output>(); int size = Marshal.SizeOf<TF_Output>();
@@ -35,40 +36,37 @@ namespace Tensorflow
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
} }


Marshal.FreeHGlobal(return_output_handle);

return return_outputs; return return_outputs;
} }


public Status Import(string file_path)
public bool Import(string file_path, string prefix = "")
{ {
var bytes = File.ReadAllBytes(file_path); var bytes = File.ReadAllBytes(file_path);
var graph_def = new Tensorflow.Buffer(bytes);
var opts = c_api.TF_NewImportGraphDefOptions();
var status = new Status();
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
return status;
}

public Status Import(byte[] bytes, string prefix = "")
{
var graph_def = new Tensorflow.Buffer(bytes);
var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
var status = new Status();
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
c_api.TF_DeleteImportGraphDefOptions(opts);
return status;
return Import(bytes, prefix: prefix);
} }


static object locker = new object();
public static Graph ImportFromPB(string file_path, string name = null)
public bool Import(byte[] bytes, string prefix = "")
{ {
lock (locker)
using (var opts = new ImportGraphDefOptions())
using (var status = new Status())
using (var graph_def = new Buffer(bytes))
{ {
var graph = tf.Graph().as_default();
var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
importer.import_graph_def(graph_def, name: name);
return graph;
as_default();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
status.Check(true);
return status.Code == TF_Code.TF_OK;
} }
} }

/*public Graph Import(string file_path, string name = null)
{
as_default();
var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
importer.import_graph_def(graph_def, name: name);
return this;
}*/
} }
} }

+ 3
- 7
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -18,10 +18,8 @@ using System;


namespace Tensorflow namespace Tensorflow
{ {
public class ImportGraphDefOptions : IDisposable
public class ImportGraphDefOptions : DisposableObject
{ {
private IntPtr _handle;

public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);


public ImportGraphDefOptions() public ImportGraphDefOptions()
@@ -39,10 +37,8 @@ namespace Tensorflow
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
} }


public void Dispose()
{
c_api.TF_DeleteImportGraphDefOptions(_handle);
}
protected override void DisposeUnManagedState(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefOptions(handle);


public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle;
public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle);


+ 2
- 2
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow> <TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.10.9</Version>
<Version>0.10.10</Version>
<Authors>Haiping Chen, Meinrad Recheis</Authors> <Authors>Haiping Chen, Meinrad Recheis</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -17,7 +17,7 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description> Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.10.9.0</AssemblyVersion>
<AssemblyVersion>0.10.10.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.9.0: <PackageReleaseNotes>Changes since v0.9.0:


1. Added full connected Convolution Neural Network example. 1. Added full connected Convolution Neural Network example.


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs View File

@@ -31,7 +31,7 @@ namespace TensorFlowNET.Examples
{ {
PrepareData(); PrepareData();
var graph = new Graph().as_default();
var graph = new Graph();
//import GraphDef from pb file //import GraphDef from pb file
graph.Import(Path.Join(dir, pbFile)); graph.Import(Path.Join(dir, pbFile));




+ 2
- 1
test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs View File

@@ -41,7 +41,8 @@ namespace TensorFlowNET.Examples
input_mean: input_mean, input_mean: input_mean,
input_std: input_std); input_std: input_std);


var graph = Graph.ImportFromPB(Path.Join(dir, pbFile));
var graph = new Graph();
graph.Import(Path.Join(dir, pbFile));
var input_operation = graph.get_operation_by_name(input_name); var input_operation = graph.get_operation_by_name(input_name);
var output_operation = graph.get_operation_by_name(output_name); var output_operation = graph.get_operation_by_name(output_name);




+ 4
- 2
test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs View File

@@ -738,7 +738,8 @@ namespace TensorFlowNET.Examples
var fileBytes = ReadTensorFromImageFile(img_path); var fileBytes = ReadTensorFromImageFile(img_path);


// import graph and variables // import graph and variables
var graph = Graph.ImportFromPB(output_graph, "");
var graph = new Graph();
graph.Import(output_graph, "");


Tensor input = graph.OperationByName("Placeholder"); Tensor input = graph.OperationByName("Placeholder");
Tensor output = graph.OperationByName("final_result"); Tensor output = graph.OperationByName("final_result");
@@ -778,7 +779,8 @@ namespace TensorFlowNET.Examples
if (!File.Exists(output_graph)) if (!File.Exists(output_graph))
return; return;


var graph = Graph.ImportFromPB(output_graph);
var graph = new Graph();
graph.Import(output_graph);
var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding(); var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();


tf_with(tf.Session(graph), sess => tf_with(tf.Session(graph), sess =>


Loading…
Cancel
Save