diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs
index 3b7cca5a..bd2d1295 100644
--- a/src/TensorFlowNET.Core/Device/c_api.device.cs
+++ b/src/TensorFlowNET.Core/Device/c_api.device.cs
@@ -67,7 +67,7 @@ namespace Tensorflow
/// TF_Status*
/// TFE_TensorHandle*
[DllImport(TensorFlowLibName)]
- public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);
+ public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);
///
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
index 2fc3f40c..4aad851f 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
@@ -48,7 +48,7 @@ namespace Tensorflow.Eager
{
for (int i = 0; i < inputs.Length; ++i)
{
- SafeTensorHandleHandle tensor_handle = inputs[i] switch
+ SafeEagerTensorHandle tensor_handle = inputs[i] switch
{
EagerTensor et => et.EagerTensorHandle,
Tensor nd => nd.EagerTensorHandle,
@@ -61,7 +61,7 @@ namespace Tensorflow.Eager
if (status.ok() && attrs != null)
SetOpAttrs(op, attrs);
- var outputs = new SafeTensorHandleHandle[num_outputs];
+ var outputs = new SafeEagerTensorHandle[num_outputs];
if (status.ok())
{
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
index 20049952..3bab7c07 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
@@ -141,7 +141,7 @@ namespace Tensorflow.Eager
num_retvals += (int)delta;
}
- var retVals = new SafeTensorHandleHandle[num_retvals];
+ var retVals = new SafeEagerTensorHandle[num_retvals];
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
status.Check(true);
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
index fa7309e3..1390daf2 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
@@ -12,7 +12,7 @@ namespace Tensorflow.Eager
NewEagerTensorHandle(handle);
}
- public EagerTensor(SafeTensorHandleHandle handle)
+ public EagerTensor(SafeEagerTensorHandle handle)
{
_id = ops.uid();
_eagerTensorHandle = handle;
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs
index addb93de..30a13312 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs
@@ -6,17 +6,19 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor : Tensor
{
- public override string Device
- {
- get
- {
- using var _ = EagerTensorHandle.Lease();
- return c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.Status.Handle));
- }
- }
+ public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
+ public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);
+ protected override Shape GetShapeInternal()
+ {
+ var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
+ for (int i = 0; i < dims.Length; i++)
+ dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle);
+ return dims;
+ }
+
public static int GetRank(IntPtr handle)
{
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
diff --git a/src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs b/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs
similarity index 88%
rename from src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs
rename to src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs
index 9f91e881..025e6511 100644
--- a/src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs
+++ b/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs
@@ -20,13 +20,13 @@ using static Tensorflow.Binding;
namespace Tensorflow.Eager
{
- public sealed class SafeTensorHandleHandle : SafeTensorflowHandle
+ public sealed class SafeEagerTensorHandle : SafeTensorflowHandle
{
- private SafeTensorHandleHandle()
+ private SafeEagerTensorHandle()
{
}
- public SafeTensorHandleHandle(IntPtr handle)
+ public SafeEagerTensorHandle(IntPtr handle)
: base(handle)
{
}
diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index f8911bd4..d874ac93 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -94,7 +94,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status);
+ public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status);
///
///
@@ -161,7 +161,7 @@ namespace Tensorflow
///
///
///
- public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
+ public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
{
unsafe
{
@@ -173,7 +173,7 @@ namespace Tensorflow
// A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be
// non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return
// values.
- retvals[i] = new SafeTensorHandleHandle(rawReturns[i]);
+ retvals[i] = new SafeEagerTensorHandle(rawReturns[i]);
}
}
}
@@ -295,7 +295,7 @@ namespace Tensorflow
/// TFE_TensorHandle*
/// TF_Status*
[DllImport(TensorFlowLibName)]
- public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status);
+ public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status);
///
///
@@ -303,10 +303,10 @@ namespace Tensorflow
/// const tensorflow::Tensor&
/// TFE_TensorHandle*
[DllImport(TensorFlowLibName)]
- public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);
+ public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t);
+ public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t);
///
/// Sets the default execution mode (sync/async). Note that this can be
@@ -323,7 +323,7 @@ namespace Tensorflow
/// TFE_TensorHandle*
///
[DllImport(TensorFlowLibName)]
- public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h);
+ public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h);
///
/// This function will block till the operation that produces `h` has
@@ -334,7 +334,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
+ public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status);
///
@@ -344,10 +344,10 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status);
+ public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status);
+ public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status);
///
/// Returns the device of the operation that produced `h`. If `h` was produced by
@@ -360,7 +360,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);
+ public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);
///
/// Returns the name of the device in whose memory `h` resides.
@@ -369,7 +369,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);
+ public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);
///
///
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 628d1ce0..3e76d3fa 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -67,9 +67,9 @@ namespace Tensorflow
///
/// The DType of elements in this tensor.
///
- public TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
+ public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
- public ulong dtypesize => _handle == null ? 0 : c_api.TF_DataTypeSize(dtype);
+ public ulong dtypesize => (ulong)dtype.get_datatype_size();
public ulong size => _handle == null ? 0 : bytesize / dtypesize;
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
@@ -88,11 +88,11 @@ namespace Tensorflow
protected new SafeTensorHandle _handle;
public SafeTensorHandle Handle => _handle;
- protected SafeTensorHandleHandle _eagerTensorHandle;
+ protected SafeEagerTensorHandle _eagerTensorHandle;
///
/// TFE_TensorHandle
///
- public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle;
+ public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle;
protected bool _isCreatedInGraphMode;
@@ -109,19 +109,7 @@ namespace Tensorflow
if (rank < 0)
return Shape.Null;
- var dims = new Shape(new long[rank]);
-
- if (_handle == null)
- {
- c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
- }
- else
- {
- for (int i = 0; i < rank; i++)
- dims[i] = c_api.TF_Dim(_handle, i);
- }
-
- return dims;
+ return GetShapeInternal();
}
set
@@ -142,6 +130,23 @@ namespace Tensorflow
}
}
+ protected virtual Shape GetShapeInternal()
+ {
+ var dims = new Shape(new long[rank]);
+
+ if (_handle == null)
+ {
+ c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
+ }
+ else
+ {
+ for (int i = 0; i < rank; i++)
+ dims[i] = c_api.TF_Dim(_handle, i);
+ }
+
+ return dims;
+ }
+
public int[] _shape_tuple()
{
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();
diff --git a/test/TensorFlowNET.Native.UnitTest/CApiTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs
index 781d29ee..2432ec1f 100644
--- a/test/TensorFlowNET.Native.UnitTest/CApiTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs
@@ -56,10 +56,10 @@ namespace Tensorflow.Native.UnitTest
protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
=> c_api.TF_SetAttrBool(desc, attrName, value);
- protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h)
+ protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h)
=> c_api.TFE_TensorHandleDataType(h);
- protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status)
+ protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleNumDims(h, status);
protected TF_Code TF_GetCode(Status s)
@@ -80,7 +80,7 @@ namespace Tensorflow.Native.UnitTest
protected ulong TF_TensorByteSize(SafeTensorHandle t)
=> c_api.TF_TensorByteSize(t);
- protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status)
+ protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_OpAddInput(op, h, status);
protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value)
@@ -95,10 +95,10 @@ namespace Tensorflow.Native.UnitTest
protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
=> c_api.TFE_NewOp(ctx, op_or_function_name, status);
- protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
+ protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
=> c_api.TFE_NewTensorHandle(t, status);
- protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
+ protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
=> c_api.TFE_Execute(op, retvals, out num_retvals, status);
protected SafeContextOptionsHandle TFE_NewContextOptions()
@@ -110,7 +110,7 @@ namespace Tensorflow.Native.UnitTest
protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
=> c_api.TFE_OpGetInputLength(op, input_name, status);
- protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status)
+ protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status)
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);
protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
@@ -125,13 +125,13 @@ namespace Tensorflow.Native.UnitTest
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status)
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);
- protected SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status)
+ protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleResolve(h, status);
- protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
+ protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status));
- protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
+ protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));
protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
@@ -146,7 +146,7 @@ namespace Tensorflow.Native.UnitTest
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.TF_DeviceListName(list, index, status);
- protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
+ protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status)
diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs
index 5873b2c9..c8502735 100644
--- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs
@@ -32,7 +32,7 @@ namespace Tensorflow.Native.UnitTest.Eager
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
- var retvals = new SafeTensorHandleHandle[2];
+ var retvals = new SafeEagerTensorHandle[2];
using (var m = TestMatrixTensorHandle())
using (var matmul = MatMulOp(ctx, m, m))
{
diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs
index ce5a287f..ff31b195 100644
--- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager
using var input1 = TestMatrixTensorHandle();
using var input2 = TestMatrixTensorHandle();
- var retvals = new SafeTensorHandleHandle[2];
+ var retvals = new SafeEagerTensorHandle[2];
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status))
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
@@ -36,7 +36,7 @@ namespace Tensorflow.Native.UnitTest.Eager
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));
- var inputs = new SafeTensorHandleHandle[] { input1, input2 };
+ var inputs = new SafeEagerTensorHandle[] { input1, input2 };
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs
index ad878115..ab0d5181 100644
--- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs
@@ -41,7 +41,7 @@ namespace Tensorflow.Native.UnitTest.Eager
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);
- var retvals = new SafeTensorHandleHandle[0];
+ var retvals = new SafeEagerTensorHandle[0];
int num_retvals;
TFE_Execute(assertOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs
index 9fc8f95e..bc430f87 100644
--- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs
@@ -39,7 +39,7 @@ namespace Tensorflow.Native.UnitTest.Eager
using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
- var retvals = new SafeTensorHandleHandle[1];
+ var retvals = new SafeEagerTensorHandle[1];
using (var shape_op = ShapeOp(ctx, hgpu))
{
TFE_OpSetDevice(shape_op, gpu_device_name, status);
diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs
index 310a933a..7c43e111 100644
--- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs
@@ -28,7 +28,7 @@ namespace Tensorflow.Native.UnitTest.Eager
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
int num_retvals = 1;
- var value_handle = new SafeTensorHandleHandle[1];
+ var value_handle = new SafeEagerTensorHandle[1];
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
{
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs
index 864c09f0..c38ba5a5 100644
--- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest.Eager
[TestClass]
public partial class CApiEagerTest : CApiTest
{
- SafeTensorHandleHandle TestMatrixTensorHandle()
+ SafeEagerTensorHandle TestMatrixTensorHandle()
{
var dims = new long[] { 2, 2 };
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th;
}
- SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b)
+ SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeEagerTensorHandle a, SafeEagerTensorHandle b)
{
using var status = TF_NewStatus();
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return false;
}
- SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a)
+ SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a)
{
using var status = TF_NewStatus();
@@ -76,27 +76,27 @@ namespace Tensorflow.Native.UnitTest.Eager
return op;
}
- unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
+ unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
{
- var var_handle = new SafeTensorHandleHandle[1];
+ var var_handle = new SafeEagerTensorHandle[1];
int num_retvals;
using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
{
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_Execute(op, var_handle, out num_retvals, status);
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
CHECK_EQ(1, num_retvals);
}
// Assign 'value' to it.
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
{
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle[0], status);
@@ -105,20 +105,20 @@ namespace Tensorflow.Native.UnitTest.Eager
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));
var value_handle = c_api.TFE_NewTensorHandle(t, status);
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpAddInput(op, value_handle, status);
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
c_api.TFE_Execute(op, null, out num_retvals, status);
- if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
+ if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
CHECK_EQ(0, num_retvals);
}
return var_handle[0];
}
- SafeTensorHandleHandle TestAxisTensorHandle()
+ SafeEagerTensorHandle TestAxisTensorHandle()
{
var dims = new long[] { 1 };
var data = new int[] { 1 };
@@ -131,7 +131,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th;
}
- SafeTensorHandleHandle TestScalarTensorHandle(bool value)
+ SafeEagerTensorHandle TestScalarTensorHandle(bool value)
{
var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
@@ -143,7 +143,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th;
}
- SafeTensorHandleHandle TestScalarTensorHandle(float value)
+ SafeEagerTensorHandle TestScalarTensorHandle(float value)
{
var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));