Browse Source

Implement SafeImportGraphDefResultsHandle as a wrapper for TF_ImportGraphDefResults

tags/v0.20
Sam Harwell Haiping 5 years ago
parent
commit
f7e61b0199
7 changed files with 103 additions and 47 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Framework/importer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +40
    -0
      src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs
  5. +25
    -14
      src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
  6. +3
    -3
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  7. +32
    -27
      test/TensorFlowNET.UnitTest/GraphTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow
{
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle);
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle));
status.Check(true);
}



+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow
return c_api.TF_NewOperation(_handle, opType, opName);
}

public Operation[] ReturnOperations(IntPtr results)
public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results)
{
TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0;


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -413,7 +413,7 @@ namespace Tensorflow
return name;
}

public TF_Output[] ReturnOutputs(IntPtr results)
public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results)
{
IntPtr return_output_handle = IntPtr.Zero;
int num_return_outputs = 0;


+ 40
- 0
src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow
{
public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle
{
private SafeImportGraphDefResultsHandle()
{
}

public SafeImportGraphDefResultsHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteImportGraphDefResults(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 25
- 14
src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs View File

@@ -1,18 +1,35 @@
using System;
using System.Runtime.InteropServices;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow
{
public class TF_ImportGraphDefResults : DisposableObject
public sealed class TF_ImportGraphDefResults : IDisposable
{
/*public IntPtr return_nodes;
public IntPtr missing_unused_key_names;
public IntPtr missing_unused_key_indexes;
public IntPtr missing_unused_key_names_data;*/

public TF_ImportGraphDefResults(IntPtr handle)
private SafeImportGraphDefResultsHandle Handle { get; }

public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle)
{
_handle = handle;
Handle = handle;
}

public TF_Output[] return_tensors
@@ -21,7 +38,7 @@ namespace Tensorflow
{
IntPtr return_output_handle = IntPtr.Zero;
int num_outputs = -1;
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle);
c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_outputs];
unsafe
{
@@ -52,13 +69,7 @@ namespace Tensorflow
}
}

public static implicit operator TF_ImportGraphDefResults(IntPtr handle)
=> new TF_ImportGraphDefResults(handle);

public static implicit operator IntPtr(TF_ImportGraphDefResults results)
=> results._handle;

protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefResults(handle);
public void Dispose()
=> Handle.Dispose();
}
}

+ 3
- 3
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -92,7 +92,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TF_ImportGraphDefResults*</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);

/// <summary>
/// Import the graph serialized in `graph_def` into `graph`.
@@ -258,7 +258,7 @@ namespace Tensorflow
/// <param name="num_opers">int*</param>
/// <param name="opers">TF_Operation***</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers);
public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers);

/// <summary>
/// Fetches the return outputs requested via
@@ -270,7 +270,7 @@ namespace Tensorflow
/// <param name="num_outputs">int*</param>
/// <param name="outputs">TF_Output**</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs);
public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs);

/// <summary>
/// This function creates a new TF_Session (which is created on success) using


+ 32
- 27
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -258,11 +258,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(0, neg.NumControlOutputs);
EXPECT_EQ(0, neg.GetControlOutputs().Length);

// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
IntPtr results;
using (var opts = c_api.TF_NewImportGraphDefOptions())
static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar)
{
using var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
@@ -270,32 +268,39 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
}

Operation scalar2 = graph.OperationByName("imported2/scalar");
Operation feed2 = graph.OperationByName("imported2/feed");
Operation neg2 = graph.OperationByName("imported2/neg");

// Check input mapping
neg_input = neg.Input(0);
EXPECT_EQ(scalar, neg_input.oper);
EXPECT_EQ(0, neg_input.index);

// Check return outputs
var return_outputs = graph.ReturnOutputs(results);
ASSERT_EQ(2, return_outputs.Length);
EXPECT_EQ(feed2, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);
return results;
}

// Check return operation
var return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, return_opers.Length);
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
c_api.TF_DeleteImportGraphDefResults(results);
// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
Operation feed2;
using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar))
{
Operation scalar2 = graph.OperationByName("imported2/scalar");
feed2 = graph.OperationByName("imported2/feed");
Operation neg2 = graph.OperationByName("imported2/neg");

// Check input mapping
neg_input = neg.Input(0);
EXPECT_EQ(scalar, neg_input.oper);
EXPECT_EQ(0, neg_input.index);

// Check return outputs
var return_outputs = graph.ReturnOutputs(results);
ASSERT_EQ(2, return_outputs.Length);
EXPECT_EQ(feed2, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);

// Check return operation
var return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, return_opers.Length);
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
}

// Import again, with control dependencies, into the same graph.
using (var opts = c_api.TF_NewImportGraphDefOptions())


Loading…
Cancel
Save