diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs index 36985438..95c2a832 100644 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -16,7 +16,7 @@ namespace Tensorflow.Eager public Context(ContextOptions opts, Status status) { - Handle = c_api.TFE_NewContext(opts, status.Handle); + Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); status.Check(true); } diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs index 8659b6ce..399c7a0e 100644 --- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs +++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs @@ -1,26 +1,33 @@ -using System; -using System.IO; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; namespace Tensorflow.Eager { - public class ContextOptions : DisposableObject + public sealed class ContextOptions : IDisposable { - public ContextOptions() : base(c_api.TFE_NewContextOptions()) - { } + public SafeContextOptionsHandle Handle { get; } - /// - /// Dispose any unmanaged resources related to given . - /// - protected sealed override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TFE_DeleteContextOptions(_handle); + public ContextOptions() + { + Handle = c_api.TFE_NewContextOptions(); + } - - public static implicit operator IntPtr(ContextOptions opts) - => opts._handle; - - public static implicit operator TFE_ContextOptions(ContextOptions opts) - => new TFE_ContextOptions(opts._handle); - + public void Dispose() + => Handle.Dispose(); } - } diff --git a/src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs b/src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs new file mode 100644 index 00000000..aac6ff00 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeContextOptionsHandle : SafeTensorflowHandle + { + public SafeContextOptionsHandle() + { + } + + public SafeContextOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteContextOptions(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/TFE_ContextOptions.cs b/src/TensorFlowNET.Core/Eager/TFE_ContextOptions.cs deleted file mode 100644 index f43d97f8..00000000 --- a/src/TensorFlowNET.Core/Eager/TFE_ContextOptions.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow.Eager -{ - public struct TFE_ContextOptions - { - IntPtr _handle; - - public TFE_ContextOptions(IntPtr handle) - => _handle = handle; - - public static implicit operator TFE_ContextOptions(IntPtr handle) - => new TFE_ContextOptions(handle); - - public static implicit operator IntPtr(TFE_ContextOptions tensor) - => tensor._handle; - - public override string ToString() - => $"TFE_ContextOptions {_handle}"; - } -} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 9b4a132f..af63dd67 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -52,7 +52,7 @@ namespace Tensorflow /// /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] - public static extern TFE_ContextOptions TFE_NewContextOptions(); + public static extern SafeContextOptionsHandle TFE_NewContextOptions(); /// /// Destroy an options object. @@ -114,7 +114,7 @@ namespace Tensorflow /// TF_Status* /// TFE_Context* [DllImport(TensorFlowLibName)] - public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); + public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern void TFE_ContextStartStep(SafeContextHandle ctx); @@ -254,7 +254,7 @@ namespace Tensorflow /// TFE_ContextOptions* /// unsigned char [DllImport(TensorFlowLibName)] - public static extern void TFE_ContextOptionsSetAsync(IntPtr opts, byte enable); + public static extern void TFE_ContextOptionsSetAsync(SafeContextOptionsHandle opts, byte enable); /// /// diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs index 2ee6cd29..07abf18e 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs @@ -102,15 +102,12 @@ namespace TensorFlowNET.UnitTest protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) => c_api.TFE_Execute(op, retvals, ref num_retvals, status); - protected IntPtr TFE_NewContextOptions() + protected SafeContextOptionsHandle TFE_NewContextOptions() => c_api.TFE_NewContextOptions(); - protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) + protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) => c_api.TFE_NewContext(opts, status); - protected void TFE_DeleteContextOptions(IntPtr opts) - => c_api.TFE_DeleteContextOptions(opts); - protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) => c_api.TFE_OpGetInputLength(op, input_name, status); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs index 220461be..84b2f54c 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Eager; namespace TensorFlowNET.UnitTest.NativeAPI { @@ -13,13 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI public void Context() { using var status = c_api.TF_NewStatus(); - var opts = c_api.TFE_NewContextOptions(); - IntPtr devices; - using (var ctx = c_api.TFE_NewContext(opts, status)) + static SafeContextHandle NewContext(SafeStatusHandle status) { - c_api.TFE_DeleteContextOptions(opts); + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + IntPtr devices; + using (var ctx = NewContext(status)) + { devices = c_api.TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs index d6b70fcd..3e6c930b 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.NativeAPI @@ -19,14 +20,18 @@ namespace TensorFlowNET.UnitTest.NativeAPI unsafe void Execute_MatMul_CPU(bool async) { using var status = TF_NewStatus(); - var opts = TFE_NewContextOptions(); - c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); + + static SafeContextHandle NewContext(bool async, SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); + return c_api.TFE_NewContext(opts, status); + } IntPtr t; - using (var ctx = TFE_NewContext(opts, status)) + using (var ctx = NewContext(async, status)) { CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteContextOptions(opts); var m = TestMatrixTensorHandle(); var matmul = MatMulOp(ctx, m, m); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs index 527ef9e6..7b6c8f38 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs @@ -2,7 +2,6 @@ using System; using Tensorflow; using Tensorflow.Eager; -using Buffer = System.Buffer; namespace TensorFlowNET.UnitTest.NativeAPI { @@ -15,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI public unsafe void OpGetInputAndOutputLengths() { using var status = TF_NewStatus(); - var opts = TFE_NewContextOptions(); - using var ctx = TFE_NewContext(opts, status); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteContextOptions(opts); var input1 = TestMatrixTensorHandle(); var input2 = TestMatrixTensorHandle(); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs index 9c6a8ecc..9c903e4b 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs @@ -1,10 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using System.Collections.Generic; using Tensorflow; using Tensorflow.Eager; -using Buffer = System.Buffer; -using System.Linq; namespace TensorFlowNET.UnitTest.NativeAPI { @@ -17,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI public unsafe void OpInferMixedTypeInputListAttrs() { using var status = TF_NewStatus(); - var opts = TFE_NewContextOptions(); - using var ctx = TFE_NewContext(opts, status); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteContextOptions(opts); var condition = TestScalarTensorHandle(true); var t1 = TestMatrixTensorHandle(); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs index 297530dc..812883c3 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs @@ -2,7 +2,6 @@ using System; using Tensorflow; using Tensorflow.Eager; -using Buffer = System.Buffer; namespace TensorFlowNET.UnitTest.NativeAPI { @@ -15,9 +14,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI public unsafe void TensorHandleDevices() { var status = c_api.TF_NewStatus(); - var opts = TFE_NewContextOptions(); - using var ctx = TFE_NewContext(opts, status); - TFE_DeleteContextOptions(opts); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); var hcpu = TestMatrixTensorHandle(); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs index 40eb880b..df555bf0 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.NativeAPI @@ -14,10 +15,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI public unsafe void Variables() { using var status = c_api.TF_NewStatus(); - var opts = TFE_NewContextOptions(); - using var ctx = TFE_NewContext(opts, status); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteContextOptions(opts); var var_handle = CreateVariable(ctx, 12.0f, status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));