@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow | |||
{ | |||
@@ -64,7 +65,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>TFE_TensorHandle*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status); | |||
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | |||
@@ -2,7 +2,7 @@ | |||
namespace Tensorflow.Eager | |||
{ | |||
public class Context : DisposableObject | |||
public sealed class Context : IDisposable | |||
{ | |||
public const int GRAPH_MODE = 0; | |||
public const int EAGER_MODE = 1; | |||
@@ -12,9 +12,11 @@ namespace Tensorflow.Eager | |||
public string scope_name = ""; | |||
bool _initialized = false; | |||
public SafeContextHandle Handle { get; } | |||
public Context(ContextOptions opts, Status status) | |||
{ | |||
_handle = c_api.TFE_NewContext(opts, status.Handle); | |||
Handle = c_api.TFE_NewContext(opts, status.Handle); | |||
status.Check(true); | |||
} | |||
@@ -29,16 +31,10 @@ namespace Tensorflow.Eager | |||
} | |||
public void start_step() | |||
=> c_api.TFE_ContextStartStep(_handle); | |||
=> c_api.TFE_ContextStartStep(Handle); | |||
public void end_step() | |||
=> c_api.TFE_ContextEndStep(_handle); | |||
/// <summary> | |||
/// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
/// </summary> | |||
protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TFE_DeleteContext(_handle); | |||
=> c_api.TFE_ContextEndStep(Handle); | |||
public bool executing_eagerly() | |||
=> default_execution_mode == EAGER_MODE; | |||
@@ -48,10 +44,7 @@ namespace Tensorflow.Eager | |||
name : | |||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
public static implicit operator IntPtr(Context ctx) | |||
=> ctx._handle; | |||
public static implicit operator TFE_Context(Context ctx) | |||
=> new TFE_Context(ctx._handle); | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
} | |||
} |
@@ -53,7 +53,7 @@ namespace Tensorflow.Eager | |||
{ | |||
object value = null; | |||
byte isList = 0; | |||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status.Handle); | |||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context.Handle, Name, attr_name, ref isList, tf.status.Handle); | |||
switch (attrType) | |||
{ | |||
case TF_AttrType.TF_ATTR_BOOL: | |||
@@ -0,0 +1,40 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Eager | |||
{ | |||
public sealed class SafeContextHandle : SafeTensorflowHandle | |||
{ | |||
public SafeContextHandle() | |||
{ | |||
} | |||
public SafeContextHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TFE_DeleteContext(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} | |||
} |
@@ -1,23 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Eager | |||
{ | |||
public struct TFE_Context | |||
{ | |||
IntPtr _handle; | |||
public TFE_Context(IntPtr handle) | |||
=> _handle = handle; | |||
public static implicit operator TFE_Context(IntPtr handle) | |||
=> new TFE_Context(handle); | |||
public static implicit operator IntPtr(TFE_Context tensor) | |||
=> tensor._handle; | |||
public override string ToString() | |||
=> $"TFE_Context {_handle}"; | |||
} | |||
} |
@@ -73,7 +73,7 @@ namespace Tensorflow | |||
public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
/// <summary> | |||
/// Returns the length (number of tensors) of the input argument `input_name` | |||
@@ -114,13 +114,13 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>TFE_Context*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TFE_Context TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||
public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TFE_Context TFE_ContextStartStep(IntPtr ctx); | |||
public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TFE_Context TFE_ContextEndStep(IntPtr ctx); | |||
public static extern void TFE_ContextEndStep(SafeContextHandle ctx); | |||
/// <summary> | |||
/// | |||
@@ -148,7 +148,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status); | |||
public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | |||
@@ -317,7 +317,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status); | |||
public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -379,7 +379,7 @@ namespace Tensorflow | |||
/// <param name="ctx"></param> | |||
/// <param name="executor"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor); | |||
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); | |||
/// <summary> | |||
/// Returns the Executor for current thread. | |||
@@ -387,7 +387,7 @@ namespace Tensorflow | |||
/// <param name="ctx"></param> | |||
/// <returns>TFE_Executor*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | |||
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||
/// <summary> | |||
/// | |||
@@ -402,7 +402,7 @@ namespace Tensorflow | |||
/// <param name="status"></param> | |||
/// <returns>EagerTensorHandle</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern SafeStatusHandle TFE_FastPathExecute(IntPtr ctx, | |||
public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, | |||
string device_name, | |||
string op_name, | |||
string name, | |||
@@ -416,7 +416,7 @@ namespace Tensorflow | |||
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern SafeStatusHandle TFE_QuickExecute(IntPtr ctx, | |||
public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, | |||
string device_name, | |||
string op_name, | |||
IntPtr[] inputs, | |||
@@ -1,6 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
using Buffer = System.Buffer; | |||
namespace TensorFlowNET.UnitTest | |||
@@ -92,7 +93,7 @@ namespace TensorFlowNET.UnitTest | |||
protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | |||
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length); | |||
protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status) | |||
protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | |||
@@ -104,10 +105,7 @@ namespace TensorFlowNET.UnitTest | |||
protected IntPtr TFE_NewContextOptions() | |||
=> c_api.TFE_NewContextOptions(); | |||
protected void TFE_DeleteContext(IntPtr t) | |||
=> c_api.TFE_DeleteContext(t); | |||
protected IntPtr TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||
protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||
=> c_api.TFE_NewContext(opts, status); | |||
protected void TFE_DeleteContextOptions(IntPtr opts) | |||
@@ -131,7 +129,7 @@ namespace TensorFlowNET.UnitTest | |||
protected void TFE_DeleteExecutor(IntPtr executor) | |||
=> c_api.TFE_DeleteExecutor(executor); | |||
protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx) | |||
protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||
=> c_api.TFE_ContextGetExecutorForThread(ctx); | |||
protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | |||
@@ -146,7 +144,7 @@ namespace TensorFlowNET.UnitTest | |||
protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | |||
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | |||
protected IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status) | |||
protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||
=> c_api.TFE_ContextListDevices(ctx, status); | |||
protected int TF_DeviceListCount(IntPtr list) | |||
@@ -161,7 +159,7 @@ namespace TensorFlowNET.UnitTest | |||
protected void TF_DeleteDeviceList(IntPtr list) | |||
=> c_api.TF_DeleteDeviceList(list); | |||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status) | |||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||
protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | |||
@@ -1,7 +1,6 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
@@ -15,14 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
using var status = c_api.TF_NewStatus(); | |||
var opts = c_api.TFE_NewContextOptions(); | |||
var ctx = c_api.TFE_NewContext(opts, status); | |||
c_api.TFE_DeleteContextOptions(opts); | |||
IntPtr devices; | |||
using (var ctx = c_api.TFE_NewContext(opts, status)) | |||
{ | |||
c_api.TFE_DeleteContextOptions(opts); | |||
var devices = c_api.TFE_ContextListDevices(ctx, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
devices = c_api.TFE_ContextListDevices(ctx, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
} | |||
c_api.TFE_DeleteContext(ctx); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
int num_devices = c_api.TF_DeviceListCount(devices); | |||
@@ -21,24 +21,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
using var status = TF_NewStatus(); | |||
var opts = TFE_NewContextOptions(); | |||
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | |||
var ctx = TFE_NewContext(opts, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteContextOptions(opts); | |||
var m = TestMatrixTensorHandle(); | |||
var matmul = MatMulOp(ctx, m, m); | |||
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||
int num_retvals = 2; | |||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||
EXPECT_EQ(1, num_retvals); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteOp(matmul); | |||
TFE_DeleteTensorHandle(m); | |||
IntPtr t; | |||
using (var ctx = TFE_NewContext(opts, status)) | |||
{ | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteContextOptions(opts); | |||
var m = TestMatrixTensorHandle(); | |||
var matmul = MatMulOp(ctx, m, m); | |||
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||
int num_retvals = 2; | |||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||
EXPECT_EQ(1, num_retvals); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteOp(matmul); | |||
TFE_DeleteTensorHandle(m); | |||
t = TFE_TensorHandleResolve(retvals[0], status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
} | |||
var t = TFE_TensorHandleResolve(retvals[0], status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
TFE_DeleteContext(ctx); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var product = new float[4]; | |||
EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); | |||
@@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
using var status = TF_NewStatus(); | |||
var opts = TFE_NewContextOptions(); | |||
var ctx = TFE_NewContext(opts, status); | |||
using var ctx = TFE_NewContext(opts, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteContextOptions(opts); | |||
@@ -57,7 +57,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
TFE_DeleteTensorHandle(input2); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
TFE_DeleteTensorHandle(retvals[1]); | |||
TFE_DeleteContext(ctx); | |||
} | |||
} | |||
} |
@@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
using var status = TF_NewStatus(); | |||
var opts = TFE_NewContextOptions(); | |||
var ctx = TFE_NewContext(opts, status); | |||
using var ctx = TFE_NewContext(opts, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteContextOptions(opts); | |||
@@ -50,7 +50,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
TFE_DeleteTensorHandle(t1); | |||
TFE_DeleteTensorHandle(t2); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
TFE_DeleteContext(ctx); | |||
} | |||
} | |||
} |
@@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
var status = c_api.TF_NewStatus(); | |||
var opts = TFE_NewContextOptions(); | |||
var ctx = TFE_NewContext(opts, status); | |||
using var ctx = TFE_NewContext(opts, status); | |||
TFE_DeleteContextOptions(opts); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
@@ -65,7 +65,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
TFE_ExecutorWaitForAllPendingNodes(executor, status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteExecutor(executor); | |||
TFE_DeleteContext(ctx); | |||
} | |||
} | |||
} |
@@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
using var status = c_api.TF_NewStatus(); | |||
var opts = TFE_NewContextOptions(); | |||
var ctx = TFE_NewContext(opts, status); | |||
using var ctx = TFE_NewContext(opts, status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteContextOptions(opts); | |||
@@ -47,7 +47,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
TFE_DeleteTensorHandle(var_handle); | |||
TFE_DeleteTensorHandle(value_handle[0]); | |||
TFE_DeleteContext(ctx); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
} | |||
} | |||
@@ -1,6 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.NativeAPI | |||
@@ -25,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return th; | |||
} | |||
IntPtr MatMulOp(IntPtr ctx, IntPtr a, IntPtr b) | |||
IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||
{ | |||
using var status = TF_NewStatus(); | |||
@@ -40,7 +41,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return op; | |||
} | |||
bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type) | |||
bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | |||
{ | |||
var status = TF_NewStatus(); | |||
var devices = TFE_ContextListDevices(ctx, status); | |||
@@ -65,7 +66,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return false; | |||
} | |||
IntPtr ShapeOp(IntPtr ctx, IntPtr a) | |||
IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) | |||
{ | |||
using var status = TF_NewStatus(); | |||
@@ -78,7 +79,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return op; | |||
} | |||
unsafe IntPtr CreateVariable(IntPtr ctx, float value, SafeStatusHandle status) | |||
unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||
{ | |||
var op = TFE_NewOp(ctx, "VarHandleOp", status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||