@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -64,7 +65,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns>TFE_TensorHandle*</returns> | /// <returns>TFE_TensorHandle*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status); | |||||
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | ||||
@@ -2,7 +2,7 @@ | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
public class Context : DisposableObject | |||||
public sealed class Context : IDisposable | |||||
{ | { | ||||
public const int GRAPH_MODE = 0; | public const int GRAPH_MODE = 0; | ||||
public const int EAGER_MODE = 1; | public const int EAGER_MODE = 1; | ||||
@@ -12,9 +12,11 @@ namespace Tensorflow.Eager | |||||
public string scope_name = ""; | public string scope_name = ""; | ||||
bool _initialized = false; | bool _initialized = false; | ||||
public SafeContextHandle Handle { get; } | |||||
public Context(ContextOptions opts, Status status) | public Context(ContextOptions opts, Status status) | ||||
{ | { | ||||
_handle = c_api.TFE_NewContext(opts, status.Handle); | |||||
Handle = c_api.TFE_NewContext(opts, status.Handle); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
@@ -29,16 +31,10 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
public void start_step() | public void start_step() | ||||
=> c_api.TFE_ContextStartStep(_handle); | |||||
=> c_api.TFE_ContextStartStep(Handle); | |||||
public void end_step() | public void end_step() | ||||
=> c_api.TFE_ContextEndStep(_handle); | |||||
/// <summary> | |||||
/// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
/// </summary> | |||||
protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
=> c_api.TFE_DeleteContext(_handle); | |||||
=> c_api.TFE_ContextEndStep(Handle); | |||||
public bool executing_eagerly() | public bool executing_eagerly() | ||||
=> default_execution_mode == EAGER_MODE; | => default_execution_mode == EAGER_MODE; | ||||
@@ -48,10 +44,7 @@ namespace Tensorflow.Eager | |||||
name : | name : | ||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | ||||
public static implicit operator IntPtr(Context ctx) | |||||
=> ctx._handle; | |||||
public static implicit operator TFE_Context(Context ctx) | |||||
=> new TFE_Context(ctx._handle); | |||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
} | } | ||||
} | } |
@@ -53,7 +53,7 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
object value = null; | object value = null; | ||||
byte isList = 0; | byte isList = 0; | ||||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status.Handle); | |||||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context.Handle, Name, attr_name, ref isList, tf.status.Handle); | |||||
switch (attrType) | switch (attrType) | ||||
{ | { | ||||
case TF_AttrType.TF_ATTR_BOOL: | case TF_AttrType.TF_ATTR_BOOL: | ||||
@@ -0,0 +1,40 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public sealed class SafeContextHandle : SafeTensorflowHandle | |||||
{ | |||||
public SafeContextHandle() | |||||
{ | |||||
} | |||||
public SafeContextHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TFE_DeleteContext(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -1,23 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public struct TFE_Context | |||||
{ | |||||
IntPtr _handle; | |||||
public TFE_Context(IntPtr handle) | |||||
=> _handle = handle; | |||||
public static implicit operator TFE_Context(IntPtr handle) | |||||
=> new TFE_Context(handle); | |||||
public static implicit operator IntPtr(TFE_Context tensor) | |||||
=> tensor._handle; | |||||
public override string ToString() | |||||
=> $"TFE_Context {_handle}"; | |||||
} | |||||
} |
@@ -73,7 +73,7 @@ namespace Tensorflow | |||||
public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the length (number of tensors) of the input argument `input_name` | /// Returns the length (number of tensors) of the input argument `input_name` | ||||
@@ -114,13 +114,13 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns>TFE_Context*</returns> | /// <returns>TFE_Context*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Context TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||||
public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Context TFE_ContextStartStep(IntPtr ctx); | |||||
public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Context TFE_ContextEndStep(IntPtr ctx); | |||||
public static extern void TFE_ContextEndStep(SafeContextHandle ctx); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -148,7 +148,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status); | |||||
public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | ||||
@@ -317,7 +317,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status); | |||||
public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -379,7 +379,7 @@ namespace Tensorflow | |||||
/// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
/// <param name="executor"></param> | /// <param name="executor"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor); | |||||
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the Executor for current thread. | /// Returns the Executor for current thread. | ||||
@@ -387,7 +387,7 @@ namespace Tensorflow | |||||
/// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
/// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | |||||
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -402,7 +402,7 @@ namespace Tensorflow | |||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
/// <returns>EagerTensorHandle</returns> | /// <returns>EagerTensorHandle</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeStatusHandle TFE_FastPathExecute(IntPtr ctx, | |||||
public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, | |||||
string device_name, | string device_name, | ||||
string op_name, | string op_name, | ||||
string name, | string name, | ||||
@@ -416,7 +416,7 @@ namespace Tensorflow | |||||
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeStatusHandle TFE_QuickExecute(IntPtr ctx, | |||||
public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, | |||||
string device_name, | string device_name, | ||||
string op_name, | string op_name, | ||||
IntPtr[] inputs, | IntPtr[] inputs, | ||||
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | |||||
using Buffer = System.Buffer; | using Buffer = System.Buffer; | ||||
namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
@@ -92,7 +93,7 @@ namespace TensorFlowNET.UnitTest | |||||
protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | ||||
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length); | => c_api.TFE_OpSetAttrString(op, attr_name, value, length); | ||||
protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status) | |||||
protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | ||||
@@ -104,10 +105,7 @@ namespace TensorFlowNET.UnitTest | |||||
protected IntPtr TFE_NewContextOptions() | protected IntPtr TFE_NewContextOptions() | ||||
=> c_api.TFE_NewContextOptions(); | => c_api.TFE_NewContextOptions(); | ||||
protected void TFE_DeleteContext(IntPtr t) | |||||
=> c_api.TFE_DeleteContext(t); | |||||
protected IntPtr TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||||
protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||||
=> c_api.TFE_NewContext(opts, status); | => c_api.TFE_NewContext(opts, status); | ||||
protected void TFE_DeleteContextOptions(IntPtr opts) | protected void TFE_DeleteContextOptions(IntPtr opts) | ||||
@@ -131,7 +129,7 @@ namespace TensorFlowNET.UnitTest | |||||
protected void TFE_DeleteExecutor(IntPtr executor) | protected void TFE_DeleteExecutor(IntPtr executor) | ||||
=> c_api.TFE_DeleteExecutor(executor); | => c_api.TFE_DeleteExecutor(executor); | ||||
protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx) | |||||
protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||||
=> c_api.TFE_ContextGetExecutorForThread(ctx); | => c_api.TFE_ContextGetExecutorForThread(ctx); | ||||
protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | ||||
@@ -146,7 +144,7 @@ namespace TensorFlowNET.UnitTest | |||||
protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | ||||
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | ||||
protected IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status) | |||||
protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||||
=> c_api.TFE_ContextListDevices(ctx, status); | => c_api.TFE_ContextListDevices(ctx, status); | ||||
protected int TF_DeviceListCount(IntPtr list) | protected int TF_DeviceListCount(IntPtr list) | ||||
@@ -161,7 +159,7 @@ namespace TensorFlowNET.UnitTest | |||||
protected void TF_DeleteDeviceList(IntPtr list) | protected void TF_DeleteDeviceList(IntPtr list) | ||||
=> c_api.TF_DeleteDeviceList(list); | => c_api.TF_DeleteDeviceList(list); | ||||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status) | |||||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | ||||
@@ -1,7 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | |||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
{ | { | ||||
@@ -15,14 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
{ | { | ||||
using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
var opts = c_api.TFE_NewContextOptions(); | var opts = c_api.TFE_NewContextOptions(); | ||||
var ctx = c_api.TFE_NewContext(opts, status); | |||||
c_api.TFE_DeleteContextOptions(opts); | |||||
IntPtr devices; | |||||
using (var ctx = c_api.TFE_NewContext(opts, status)) | |||||
{ | |||||
c_api.TFE_DeleteContextOptions(opts); | |||||
var devices = c_api.TFE_ContextListDevices(ctx, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
devices = c_api.TFE_ContextListDevices(ctx, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
} | |||||
c_api.TFE_DeleteContext(ctx); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
int num_devices = c_api.TF_DeviceListCount(devices); | int num_devices = c_api.TF_DeviceListCount(devices); | ||||
@@ -21,24 +21,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | ||||
var ctx = TFE_NewContext(opts, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteContextOptions(opts); | |||||
var m = TestMatrixTensorHandle(); | |||||
var matmul = MatMulOp(ctx, m, m); | |||||
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||||
int num_retvals = 2; | |||||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(1, num_retvals); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteOp(matmul); | |||||
TFE_DeleteTensorHandle(m); | |||||
IntPtr t; | |||||
using (var ctx = TFE_NewContext(opts, status)) | |||||
{ | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteContextOptions(opts); | |||||
var m = TestMatrixTensorHandle(); | |||||
var matmul = MatMulOp(ctx, m, m); | |||||
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||||
int num_retvals = 2; | |||||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(1, num_retvals); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteOp(matmul); | |||||
TFE_DeleteTensorHandle(m); | |||||
t = TFE_TensorHandleResolve(retvals[0], status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteTensorHandle(retvals[0]); | |||||
} | |||||
var t = TFE_TensorHandleResolve(retvals[0], status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteTensorHandle(retvals[0]); | |||||
TFE_DeleteContext(ctx); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var product = new float[4]; | var product = new float[4]; | ||||
EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); | EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); | ||||
@@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
var ctx = TFE_NewContext(opts, status); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
@@ -57,7 +57,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_DeleteTensorHandle(input2); | TFE_DeleteTensorHandle(input2); | ||||
TFE_DeleteTensorHandle(retvals[0]); | TFE_DeleteTensorHandle(retvals[0]); | ||||
TFE_DeleteTensorHandle(retvals[1]); | TFE_DeleteTensorHandle(retvals[1]); | ||||
TFE_DeleteContext(ctx); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
var ctx = TFE_NewContext(opts, status); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
@@ -50,7 +50,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_DeleteTensorHandle(t1); | TFE_DeleteTensorHandle(t1); | ||||
TFE_DeleteTensorHandle(t2); | TFE_DeleteTensorHandle(t2); | ||||
TFE_DeleteTensorHandle(retvals[0]); | TFE_DeleteTensorHandle(retvals[0]); | ||||
TFE_DeleteContext(ctx); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
{ | { | ||||
var status = c_api.TF_NewStatus(); | var status = c_api.TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
var ctx = TFE_NewContext(opts, status); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
@@ -65,7 +65,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_ExecutorWaitForAllPendingNodes(executor, status); | TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteExecutor(executor); | TFE_DeleteExecutor(executor); | ||||
TFE_DeleteContext(ctx); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
{ | { | ||||
using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
var ctx = TFE_NewContext(opts, status); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
@@ -47,7 +47,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_DeleteTensorHandle(var_handle); | TFE_DeleteTensorHandle(var_handle); | ||||
TFE_DeleteTensorHandle(value_handle[0]); | TFE_DeleteTensorHandle(value_handle[0]); | ||||
TFE_DeleteContext(ctx); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
} | } | ||||
} | } | ||||
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
@@ -25,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
return th; | return th; | ||||
} | } | ||||
IntPtr MatMulOp(IntPtr ctx, IntPtr a, IntPtr b) | |||||
IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -40,7 +41,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
return op; | return op; | ||||
} | } | ||||
bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type) | |||||
bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | |||||
{ | { | ||||
var status = TF_NewStatus(); | var status = TF_NewStatus(); | ||||
var devices = TFE_ContextListDevices(ctx, status); | var devices = TFE_ContextListDevices(ctx, status); | ||||
@@ -65,7 +66,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
return false; | return false; | ||||
} | } | ||||
IntPtr ShapeOp(IntPtr ctx, IntPtr a) | |||||
IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -78,7 +79,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
return op; | return op; | ||||
} | } | ||||
unsafe IntPtr CreateVariable(IntPtr ctx, float value, SafeStatusHandle status) | |||||
unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||||
{ | { | ||||
var op = TFE_NewOp(ctx, "VarHandleOp", status); | var op = TFE_NewOp(ctx, "VarHandleOp", status); | ||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | ||||