Browse Source

Change Graph.Import interface.

tags/v0.12
Oceania2018 6 years ago
parent
commit
f32692b1c7
7 changed files with 44 additions and 46 deletions
  1. +10
    -9
      src/TensorFlowNET.Core/Framework/importer.py.cs
  2. +22
    -24
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  3. +3
    -7
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  4. +2
    -2
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  5. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs
  6. +2
    -1
      test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs
  7. +4
    -2
      test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs

+ 10
- 9
src/TensorFlowNET.Core/Framework/importer.py.cs View File

@@ -54,16 +54,17 @@ namespace Tensorflow
input_map = _ConvertInputMapValues(name, input_map);
});

var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);

var bytes = graph_def.ToByteString().ToArray();
IntPtr buffer = c_api_util.tf_buffer(bytes);

var status = new Status();
// need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true);
using (var buffer = c_api_util.tf_buffer(bytes))
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
using (var status = new Status())
{
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true);
c_api.TF_DeleteImportGraphDefResults(results);
}

_ProcessNewOps(graph);



+ 22
- 24
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -23,6 +23,7 @@ namespace Tensorflow
{
public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s)
{
as_default();
var num_return_outputs = opts.NumReturnOutputs;
var return_outputs = new TF_Output[num_return_outputs];
int size = Marshal.SizeOf<TF_Output>();
@@ -35,40 +36,37 @@ namespace Tensorflow
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
}

Marshal.FreeHGlobal(return_output_handle);

return return_outputs;
}

public Status Import(string file_path)
public bool Import(string file_path, string prefix = "")
{
var bytes = File.ReadAllBytes(file_path);
var graph_def = new Tensorflow.Buffer(bytes);
var opts = c_api.TF_NewImportGraphDefOptions();
var status = new Status();
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
return status;
}

public Status Import(byte[] bytes, string prefix = "")
{
var graph_def = new Tensorflow.Buffer(bytes);
var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
var status = new Status();
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
c_api.TF_DeleteImportGraphDefOptions(opts);
return status;
return Import(bytes, prefix: prefix);
}

static object locker = new object();
public static Graph ImportFromPB(string file_path, string name = null)
public bool Import(byte[] bytes, string prefix = "")
{
lock (locker)
using (var opts = new ImportGraphDefOptions())
using (var status = new Status())
using (var graph_def = new Buffer(bytes))
{
var graph = tf.Graph().as_default();
var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
importer.import_graph_def(graph_def, name: name);
return graph;
as_default();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
status.Check(true);
return status.Code == TF_Code.TF_OK;
}
}

/*public Graph Import(string file_path, string name = null)
{
as_default();
var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
importer.import_graph_def(graph_def, name: name);
return this;
}*/
}
}

+ 3
- 7
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -18,10 +18,8 @@ using System;

namespace Tensorflow
{
public class ImportGraphDefOptions : IDisposable
public class ImportGraphDefOptions : DisposableObject
{
private IntPtr _handle;

public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);

public ImportGraphDefOptions()
@@ -39,10 +37,8 @@ namespace Tensorflow
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
}

public void Dispose()
{
c_api.TF_DeleteImportGraphDefOptions(_handle);
}
protected override void DisposeUnManagedState(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefOptions(handle);

public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle;
public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle);


+ 2
- 2
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.10.9</Version>
<Version>0.10.10</Version>
<Authors>Haiping Chen, Meinrad Recheis</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -17,7 +17,7 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.10.9.0</AssemblyVersion>
<AssemblyVersion>0.10.10.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.9.0:

1. Added full connected Convolution Neural Network example.


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs View File

@@ -31,7 +31,7 @@ namespace TensorFlowNET.Examples
{
PrepareData();
var graph = new Graph().as_default();
var graph = new Graph();
//import GraphDef from pb file
graph.Import(Path.Join(dir, pbFile));



+ 2
- 1
test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs View File

@@ -41,7 +41,8 @@ namespace TensorFlowNET.Examples
input_mean: input_mean,
input_std: input_std);

var graph = Graph.ImportFromPB(Path.Join(dir, pbFile));
var graph = new Graph();
graph.Import(Path.Join(dir, pbFile));
var input_operation = graph.get_operation_by_name(input_name);
var output_operation = graph.get_operation_by_name(output_name);



+ 4
- 2
test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs View File

@@ -738,7 +738,8 @@ namespace TensorFlowNET.Examples
var fileBytes = ReadTensorFromImageFile(img_path);

// import graph and variables
var graph = Graph.ImportFromPB(output_graph, "");
var graph = new Graph();
graph.Import(output_graph, "");

Tensor input = graph.OperationByName("Placeholder");
Tensor output = graph.OperationByName("final_result");
@@ -778,7 +779,8 @@ namespace TensorFlowNET.Examples
if (!File.Exists(output_graph))
return;

var graph = Graph.ImportFromPB(output_graph);
var graph = new Graph();
graph.Import(output_graph);
var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

tf_with(tf.Session(graph), sess =>


Loading…
Cancel
Save