From ab6d624ea886e42af069485639ea4b97eb8b28a2 Mon Sep 17 00:00:00 2001 From: Sam Harwell Date: Tue, 4 Feb 2020 18:15:02 -0800 Subject: [PATCH] Implement SafeContextHandle as a wrapper for TFE_Context --- src/TensorFlowNET.Core/Device/c_api.device.cs | 3 +- src/TensorFlowNET.Core/Eager/Context.cs | 23 ++++------- .../Eager/EagerOperation.cs | 2 +- .../Eager/SafeContextHandle.cs | 40 +++++++++++++++++++ src/TensorFlowNET.Core/Eager/TFE_Context.cs | 23 ----------- src/TensorFlowNET.Core/Eager/c_api.eager.cs | 20 +++++----- .../NativeAPI/CApiTest.cs | 14 +++---- .../NativeAPI/Eager/CApi.Eager.Context.cs | 13 +++--- .../Eager/CApi.Eager.Execute_MatMul_CPU.cs | 36 +++++++++-------- .../CApi.Eager.OpGetInputAndOutputLengths.cs | 3 +- ...pi.Eager.OpInferMixedTypeInputListAttrs.cs | 3 +- .../Eager/CApi.Eager.TensorHandleDevices.cs | 3 +- .../NativeAPI/Eager/CApi.Eager.Variables.cs | 3 +- .../NativeAPI/Eager/CApi.Eager.cs | 9 +++-- 14 files changed, 103 insertions(+), 92 deletions(-) create mode 100644 src/TensorFlowNET.Core/Eager/SafeContextHandle.cs delete mode 100644 src/TensorFlowNET.Core/Eager/TFE_Context.cs diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs index 698aa227..8715ad26 100644 --- a/src/TensorFlowNET.Core/Device/c_api.device.cs +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -16,6 +16,7 @@ using System; using System.Runtime.InteropServices; +using Tensorflow.Eager; namespace Tensorflow { @@ -64,7 +65,7 @@ namespace Tensorflow /// TF_Status* /// TFE_TensorHandle* [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status); + public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); /// /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs index d697fafa..36985438 100644 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -2,7 +2,7 @@ namespace Tensorflow.Eager { - public class Context : DisposableObject + public sealed class Context : IDisposable { public const int GRAPH_MODE = 0; public const int EAGER_MODE = 1; @@ -12,9 +12,11 @@ namespace Tensorflow.Eager public string scope_name = ""; bool _initialized = false; + public SafeContextHandle Handle { get; } + public Context(ContextOptions opts, Status status) { - _handle = c_api.TFE_NewContext(opts, status.Handle); + Handle = c_api.TFE_NewContext(opts, status.Handle); status.Check(true); } @@ -29,16 +31,10 @@ namespace Tensorflow.Eager } public void start_step() - => c_api.TFE_ContextStartStep(_handle); + => c_api.TFE_ContextStartStep(Handle); public void end_step() - => c_api.TFE_ContextEndStep(_handle); - - /// - /// Dispose any unmanaged resources related to given . - /// - protected sealed override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TFE_DeleteContext(_handle); + => c_api.TFE_ContextEndStep(Handle); public bool executing_eagerly() => default_execution_mode == EAGER_MODE; @@ -48,10 +44,7 @@ namespace Tensorflow.Eager name : "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; - public static implicit operator IntPtr(Context ctx) - => ctx._handle; - - public static implicit operator TFE_Context(Context ctx) - => new TFE_Context(ctx._handle); + public void Dispose() + => Handle.Dispose(); } } diff --git a/src/TensorFlowNET.Core/Eager/EagerOperation.cs b/src/TensorFlowNET.Core/Eager/EagerOperation.cs index fdcf1cd5..982198f8 100644 --- a/src/TensorFlowNET.Core/Eager/EagerOperation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerOperation.cs @@ -53,7 +53,7 @@ namespace Tensorflow.Eager { object value = null; byte isList = 0; - var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status.Handle); + var attrType = c_api.TFE_OpNameGetAttrType(tf.context.Handle, Name, attr_name, ref isList, tf.status.Handle); switch (attrType) { case TF_AttrType.TF_ATTR_BOOL: diff --git a/src/TensorFlowNET.Core/Eager/SafeContextHandle.cs b/src/TensorFlowNET.Core/Eager/SafeContextHandle.cs new file mode 100644 index 00000000..9fa8db79 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeContextHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeContextHandle : SafeTensorflowHandle + { + public SafeContextHandle() + { + } + + public SafeContextHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteContext(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/TFE_Context.cs b/src/TensorFlowNET.Core/Eager/TFE_Context.cs deleted file mode 100644 index dc16909d..00000000 --- a/src/TensorFlowNET.Core/Eager/TFE_Context.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow.Eager -{ - public struct TFE_Context - { - IntPtr _handle; - - public TFE_Context(IntPtr handle) - => _handle = handle; - - public static implicit operator TFE_Context(IntPtr handle) - => new TFE_Context(handle); - - public static implicit operator IntPtr(TFE_Context tensor) - => tensor._handle; - - public override string ToString() - => $"TFE_Context {_handle}"; - } -} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 4e9f59c1..9b4a132f 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -73,7 +73,7 @@ namespace Tensorflow public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); + public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); /// /// Returns the length (number of tensors) of the input argument `input_name` @@ -114,13 +114,13 @@ namespace Tensorflow /// TF_Status* /// TFE_Context* [DllImport(TensorFlowLibName)] - public static extern TFE_Context TFE_NewContext(IntPtr opts, SafeStatusHandle status); + public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern TFE_Context TFE_ContextStartStep(IntPtr ctx); + public static extern void TFE_ContextStartStep(SafeContextHandle ctx); [DllImport(TensorFlowLibName)] - public static extern TFE_Context TFE_ContextEndStep(IntPtr ctx); + public static extern void TFE_ContextEndStep(SafeContextHandle ctx); /// /// @@ -148,7 +148,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status); + public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); /// /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This @@ -317,7 +317,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status); + public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); /// /// @@ -379,7 +379,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor); + public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); /// /// Returns the Executor for current thread. @@ -387,7 +387,7 @@ namespace Tensorflow /// /// TFE_Executor* [DllImport(TensorFlowLibName)] - public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); + public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); /// /// @@ -402,7 +402,7 @@ namespace Tensorflow /// /// EagerTensorHandle [DllImport(TensorFlowLibName)] - public static extern SafeStatusHandle TFE_FastPathExecute(IntPtr ctx, + public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, string device_name, string op_name, string name, @@ -416,7 +416,7 @@ namespace Tensorflow public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); [DllImport(TensorFlowLibName)] - public static extern SafeStatusHandle TFE_QuickExecute(IntPtr ctx, + public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, string device_name, string op_name, IntPtr[] inputs, diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs index 17d95945..2ee6cd29 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Eager; using Buffer = System.Buffer; namespace TensorFlowNET.UnitTest @@ -92,7 +93,7 @@ namespace TensorFlowNET.UnitTest protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) => c_api.TFE_OpSetAttrString(op, attr_name, value, length); - protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status) + protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) => c_api.TFE_NewOp(ctx, op_or_function_name, status); protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) @@ -104,10 +105,7 @@ namespace TensorFlowNET.UnitTest protected IntPtr TFE_NewContextOptions() => c_api.TFE_NewContextOptions(); - protected void TFE_DeleteContext(IntPtr t) - => c_api.TFE_DeleteContext(t); - - protected IntPtr TFE_NewContext(IntPtr opts, SafeStatusHandle status) + protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) => c_api.TFE_NewContext(opts, status); protected void TFE_DeleteContextOptions(IntPtr opts) @@ -131,7 +129,7 @@ namespace TensorFlowNET.UnitTest protected void TFE_DeleteExecutor(IntPtr executor) => c_api.TFE_DeleteExecutor(executor); - protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx) + protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) => c_api.TFE_ContextGetExecutorForThread(ctx); protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) @@ -146,7 +144,7 @@ namespace TensorFlowNET.UnitTest protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); - protected IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status) + protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) => c_api.TFE_ContextListDevices(ctx, status); protected int TF_DeviceListCount(IntPtr list) @@ -161,7 +159,7 @@ namespace TensorFlowNET.UnitTest protected void TF_DeleteDeviceList(IntPtr list) => c_api.TF_DeleteDeviceList(list); - protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status) + protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs index 617edb8b..220461be 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs @@ -1,7 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; -using Tensorflow.Eager; namespace TensorFlowNET.UnitTest.NativeAPI { @@ -15,14 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI { using var status = c_api.TF_NewStatus(); var opts = c_api.TFE_NewContextOptions(); - var ctx = c_api.TFE_NewContext(opts, status); - c_api.TFE_DeleteContextOptions(opts); + IntPtr devices; + using (var ctx = c_api.TFE_NewContext(opts, status)) + { + c_api.TFE_DeleteContextOptions(opts); - var devices = c_api.TFE_ContextListDevices(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + devices = c_api.TFE_ContextListDevices(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } - c_api.TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); int num_devices = c_api.TF_DeviceListCount(devices); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs index 554714bc..d6b70fcd 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs @@ -21,24 +21,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI using var status = TF_NewStatus(); var opts = TFE_NewContextOptions(); c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); - var ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteContextOptions(opts); - var m = TestMatrixTensorHandle(); - var matmul = MatMulOp(ctx, m, m); - var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; - int num_retvals = 2; - c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); - EXPECT_EQ(1, num_retvals); - EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteOp(matmul); - TFE_DeleteTensorHandle(m); + IntPtr t; + using (var ctx = TFE_NewContext(opts, status)) + { + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_DeleteContextOptions(opts); + + var m = TestMatrixTensorHandle(); + var matmul = MatMulOp(ctx, m, m); + var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; + int num_retvals = 2; + c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); + EXPECT_EQ(1, num_retvals); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + + t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_DeleteTensorHandle(retvals[0]); + } - var t = TFE_TensorHandleResolve(retvals[0], status); - ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); var product = new float[4]; EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs index 5128f114..527ef9e6 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI { using var status = TF_NewStatus(); var opts = TFE_NewContextOptions(); - var ctx = TFE_NewContext(opts, status); + using var ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); TFE_DeleteContextOptions(opts); @@ -57,7 +57,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_DeleteTensorHandle(input2); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteTensorHandle(retvals[1]); - TFE_DeleteContext(ctx); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs index 7f6f74fc..9c6a8ecc 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs @@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI { using var status = TF_NewStatus(); var opts = TFE_NewContextOptions(); - var ctx = TFE_NewContext(opts, status); + using var ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); TFE_DeleteContextOptions(opts); @@ -50,7 +50,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_DeleteTensorHandle(t1); TFE_DeleteTensorHandle(t2); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs index aa8d9ffb..297530dc 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI { var status = c_api.TF_NewStatus(); var opts = TFE_NewContextOptions(); - var ctx = TFE_NewContext(opts, status); + using var ctx = TFE_NewContext(opts, status); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); @@ -65,7 +65,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_ExecutorWaitForAllPendingNodes(executor, status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); TFE_DeleteExecutor(executor); - TFE_DeleteContext(ctx); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs index 0713ec15..40eb880b 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs @@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI { using var status = c_api.TF_NewStatus(); var opts = TFE_NewContextOptions(); - var ctx = TFE_NewContext(opts, status); + using var ctx = TFE_NewContext(opts, status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); TFE_DeleteContextOptions(opts); @@ -47,7 +47,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_DeleteTensorHandle(var_handle); TFE_DeleteTensorHandle(value_handle[0]); - TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs index 34184672..ebf17f7c 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.NativeAPI @@ -25,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI return th; } - IntPtr MatMulOp(IntPtr ctx, IntPtr a, IntPtr b) + IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) { using var status = TF_NewStatus(); @@ -40,7 +41,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI return op; } - bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type) + bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) { var status = TF_NewStatus(); var devices = TFE_ContextListDevices(ctx, status); @@ -65,7 +66,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI return false; } - IntPtr ShapeOp(IntPtr ctx, IntPtr a) + IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) { using var status = TF_NewStatus(); @@ -78,7 +79,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI return op; } - unsafe IntPtr CreateVariable(IntPtr ctx, float value, SafeStatusHandle status) + unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) { var op = TFE_NewOp(ctx, "VarHandleOp", status); if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;