@@ -62,7 +62,7 @@ namespace Tensorflow | |||
{ | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
// need to create a class ImportGraphDefWithResults with IDisposal | |||
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle); | |||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle)); | |||
status.Check(true); | |||
} | |||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||
return c_api.TF_NewOperation(_handle, opType, opName); | |||
} | |||
public Operation[] ReturnOperations(IntPtr results) | |||
public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results) | |||
{ | |||
TF_Operation return_oper_handle = new TF_Operation(); | |||
int num_return_opers = 0; | |||
@@ -413,7 +413,7 @@ namespace Tensorflow | |||
return name; | |||
} | |||
public TF_Output[] ReturnOutputs(IntPtr results) | |||
public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results) | |||
{ | |||
IntPtr return_output_handle = IntPtr.Zero; | |||
int num_return_outputs = 0; | |||
@@ -0,0 +1,40 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Util; | |||
namespace Tensorflow | |||
{ | |||
public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle | |||
{ | |||
private SafeImportGraphDefResultsHandle() | |||
{ | |||
} | |||
public SafeImportGraphDefResultsHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TF_DeleteImportGraphDefResults(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} | |||
} |
@@ -1,18 +1,35 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
namespace Tensorflow | |||
{ | |||
public class TF_ImportGraphDefResults : DisposableObject | |||
public sealed class TF_ImportGraphDefResults : IDisposable | |||
{ | |||
/*public IntPtr return_nodes; | |||
public IntPtr missing_unused_key_names; | |||
public IntPtr missing_unused_key_indexes; | |||
public IntPtr missing_unused_key_names_data;*/ | |||
public TF_ImportGraphDefResults(IntPtr handle) | |||
private SafeImportGraphDefResultsHandle Handle { get; } | |||
public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle) | |||
{ | |||
_handle = handle; | |||
Handle = handle; | |||
} | |||
public TF_Output[] return_tensors | |||
@@ -21,7 +38,7 @@ namespace Tensorflow | |||
{ | |||
IntPtr return_output_handle = IntPtr.Zero; | |||
int num_outputs = -1; | |||
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); | |||
c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle); | |||
TF_Output[] return_outputs = new TF_Output[num_outputs]; | |||
unsafe | |||
{ | |||
@@ -52,13 +69,7 @@ namespace Tensorflow | |||
} | |||
} | |||
public static implicit operator TF_ImportGraphDefResults(IntPtr handle) | |||
=> new TF_ImportGraphDefResults(handle); | |||
public static implicit operator IntPtr(TF_ImportGraphDefResults results) | |||
=> results._handle; | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteImportGraphDefResults(handle); | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
} | |||
} |
@@ -92,7 +92,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>TF_ImportGraphDefResults*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. | |||
@@ -258,7 +258,7 @@ namespace Tensorflow | |||
/// <param name="num_opers">int*</param> | |||
/// <param name="opers">TF_Operation***</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers); | |||
public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers); | |||
/// <summary> | |||
/// Fetches the return outputs requested via | |||
@@ -270,7 +270,7 @@ namespace Tensorflow | |||
/// <param name="num_outputs">int*</param> | |||
/// <param name="outputs">TF_Output**</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | |||
public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs); | |||
/// <summary> | |||
/// This function creates a new TF_Session (which is created on success) using | |||
@@ -258,11 +258,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
EXPECT_EQ(0, neg.NumControlOutputs); | |||
EXPECT_EQ(0, neg.GetControlOutputs().Length); | |||
// Import it again, with an input mapping, return outputs, and a return | |||
// operation, into the same graph. | |||
IntPtr results; | |||
using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar) | |||
{ | |||
using var opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||
@@ -270,32 +268,39 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
} | |||
Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
Operation feed2 = graph.OperationByName("imported2/feed"); | |||
Operation neg2 = graph.OperationByName("imported2/neg"); | |||
// Check input mapping | |||
neg_input = neg.Input(0); | |||
EXPECT_EQ(scalar, neg_input.oper); | |||
EXPECT_EQ(0, neg_input.index); | |||
// Check return outputs | |||
var return_outputs = graph.ReturnOutputs(results); | |||
ASSERT_EQ(2, return_outputs.Length); | |||
EXPECT_EQ(feed2, return_outputs[0].oper); | |||
EXPECT_EQ(0, return_outputs[0].index); | |||
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
return results; | |||
} | |||
// Check return operation | |||
var return_opers = graph.ReturnOperations(results); | |||
ASSERT_EQ(1, return_opers.Length); | |||
EXPECT_EQ(scalar2, return_opers[0]); // not remapped | |||
c_api.TF_DeleteImportGraphDefResults(results); | |||
// Import it again, with an input mapping, return outputs, and a return | |||
// operation, into the same graph. | |||
Operation feed2; | |||
using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar)) | |||
{ | |||
Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
feed2 = graph.OperationByName("imported2/feed"); | |||
Operation neg2 = graph.OperationByName("imported2/neg"); | |||
// Check input mapping | |||
neg_input = neg.Input(0); | |||
EXPECT_EQ(scalar, neg_input.oper); | |||
EXPECT_EQ(0, neg_input.index); | |||
// Check return outputs | |||
var return_outputs = graph.ReturnOutputs(results); | |||
ASSERT_EQ(2, return_outputs.Length); | |||
EXPECT_EQ(feed2, return_outputs[0].oper); | |||
EXPECT_EQ(0, return_outputs[0].index); | |||
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
// Check return operation | |||
var return_opers = graph.ReturnOperations(results); | |||
ASSERT_EQ(1, return_opers.Length); | |||
EXPECT_EQ(scalar2, return_opers[0]); // not remapped | |||
} | |||
// Import again, with control dependencies, into the same graph. | |||
using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||