From f32692b1c74881c4b4de2d7cf230b1e06615a2ad Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 3 Aug 2019 14:23:18 -0500 Subject: [PATCH] Change Graph.Import interface. --- .../Framework/importer.py.cs | 19 ++++---- src/TensorFlowNET.Core/Graphs/Graph.Import.cs | 46 +++++++++---------- .../Graphs/ImportGraphDefOptions.cs | 10 ++-- .../TensorFlowNET.Core.csproj | 4 +- .../ImageRecognitionInception.cs | 2 +- .../ImageProcessing/InceptionArchGoogLeNet.cs | 3 +- .../ImageProcessing/RetrainImageClassifier.cs | 6 ++- 7 files changed, 44 insertions(+), 46 deletions(-) diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.py.cs index 0c405be9..254fda19 100644 --- a/src/TensorFlowNET.Core/Framework/importer.py.cs +++ b/src/TensorFlowNET.Core/Framework/importer.py.cs @@ -54,16 +54,17 @@ namespace Tensorflow input_map = _ConvertInputMapValues(name, input_map); }); - var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); - _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); - var bytes = graph_def.ToByteString().ToArray(); - IntPtr buffer = c_api_util.tf_buffer(bytes); - - var status = new Status(); - // need to create a class ImportGraphDefWithResults with IDisposal - var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); - status.Check(true); + using (var buffer = c_api_util.tf_buffer(bytes)) + using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) + using (var status = new Status()) + { + _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); + // need to create a class ImportGraphDefWithResults with IDisposal + var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); + status.Check(true); + c_api.TF_DeleteImportGraphDefResults(results); + } _ProcessNewOps(graph); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 1c91868b..82695527 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -23,6 +23,7 @@ namespace Tensorflow { public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) { + as_default(); var num_return_outputs = opts.NumReturnOutputs; var return_outputs = new TF_Output[num_return_outputs]; int size = Marshal.SizeOf(); @@ -35,40 +36,37 @@ namespace Tensorflow return_outputs[i] = Marshal.PtrToStructure(handle); } + Marshal.FreeHGlobal(return_output_handle); + return return_outputs; } - public Status Import(string file_path) + public bool Import(string file_path, string prefix = "") { var bytes = File.ReadAllBytes(file_path); - var graph_def = new Tensorflow.Buffer(bytes); - var opts = c_api.TF_NewImportGraphDefOptions(); - var status = new Status(); - c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); - return status; - } - - public Status Import(byte[] bytes, string prefix = "") - { - var graph_def = new Tensorflow.Buffer(bytes); - var opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); - var status = new Status(); - c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); - c_api.TF_DeleteImportGraphDefOptions(opts); - return status; + return Import(bytes, prefix: prefix); } - static object locker = new object(); - public static Graph ImportFromPB(string file_path, string name = null) + public bool Import(byte[] bytes, string prefix = "") { - lock (locker) + using (var opts = new ImportGraphDefOptions()) + using (var status = new Status()) + using (var graph_def = new Buffer(bytes)) { - var graph = tf.Graph().as_default(); - var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path)); - importer.import_graph_def(graph_def, name: name); - return graph; + as_default(); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); + c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); + status.Check(true); + return status.Code == TF_Code.TF_OK; } } + + /*public Graph Import(string file_path, string name = null) + { + as_default(); + var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path)); + importer.import_graph_def(graph_def, name: name); + return this; + }*/ } } diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs index 6a0a812a..97720206 100644 --- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -18,10 +18,8 @@ using System; namespace Tensorflow { - public class ImportGraphDefOptions : IDisposable + public class ImportGraphDefOptions : DisposableObject { - private IntPtr _handle; - public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); public ImportGraphDefOptions() @@ -39,10 +37,8 @@ namespace Tensorflow c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); } - public void Dispose() - { - c_api.TF_DeleteImportGraphDefOptions(_handle); - } + protected override void DisposeUnManagedState(IntPtr handle) + => c_api.TF_DeleteImportGraphDefOptions(handle); public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 17b5191b..af74b38d 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.10.9 + 0.10.10 Haiping Chen, Meinrad Recheis SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.10.9.0 + 0.10.10.0 Changes since v0.9.0: 1. Added full connected Convolution Neural Network example. diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs index efcb0b73..548c84f4 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs @@ -31,7 +31,7 @@ namespace TensorFlowNET.Examples { PrepareData(); - var graph = new Graph().as_default(); + var graph = new Graph(); //import GraphDef from pb file graph.Import(Path.Join(dir, pbFile)); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs b/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs index f51833d2..93fa9c2c 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs @@ -41,7 +41,8 @@ namespace TensorFlowNET.Examples input_mean: input_mean, input_std: input_std); - var graph = Graph.ImportFromPB(Path.Join(dir, pbFile)); + var graph = new Graph(); + graph.Import(Path.Join(dir, pbFile)); var input_operation = graph.get_operation_by_name(input_name); var output_operation = graph.get_operation_by_name(output_name); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 72c7b296..60329304 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -738,7 +738,8 @@ namespace TensorFlowNET.Examples var fileBytes = ReadTensorFromImageFile(img_path); // import graph and variables - var graph = Graph.ImportFromPB(output_graph, ""); + var graph = new Graph(); + graph.Import(output_graph, ""); Tensor input = graph.OperationByName("Placeholder"); Tensor output = graph.OperationByName("final_result"); @@ -778,7 +779,8 @@ namespace TensorFlowNET.Examples if (!File.Exists(output_graph)) return; - var graph = Graph.ImportFromPB(output_graph); + var graph = new Graph(); + graph.Import(output_graph); var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding(); tf_with(tf.Session(graph), sess =>