diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs index 3b7cca5a..bd2d1295 100644 --- a/src/TensorFlowNET.Core/Device/c_api.device.cs +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -67,7 +67,7 @@ namespace Tensorflow /// TF_Status* /// TFE_TensorHandle* [DllImport(TensorFlowLibName)] - public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); + public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); /// /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs index 2fc3f40c..4aad851f 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -48,7 +48,7 @@ namespace Tensorflow.Eager { for (int i = 0; i < inputs.Length; ++i) { - SafeTensorHandleHandle tensor_handle = inputs[i] switch + SafeEagerTensorHandle tensor_handle = inputs[i] switch { EagerTensor et => et.EagerTensorHandle, Tensor nd => nd.EagerTensorHandle, @@ -61,7 +61,7 @@ namespace Tensorflow.Eager if (status.ok() && attrs != null) SetOpAttrs(op, attrs); - var outputs = new SafeTensorHandleHandle[num_outputs]; + var outputs = new SafeEagerTensorHandle[num_outputs]; if (status.ok()) { c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 20049952..3bab7c07 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -141,7 +141,7 @@ namespace Tensorflow.Eager num_retvals += (int)delta; } - var retVals = new SafeTensorHandleHandle[num_retvals]; + var retVals = new SafeEagerTensorHandle[num_retvals]; c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); status.Check(true); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index fa7309e3..1390daf2 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Eager NewEagerTensorHandle(handle); } - public EagerTensor(SafeTensorHandleHandle handle) + public EagerTensor(SafeEagerTensorHandle handle) { _id = ops.uid(); _eagerTensorHandle = handle; diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index addb93de..30a13312 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -6,17 +6,19 @@ namespace Tensorflow.Eager { public partial class EagerTensor : Tensor { - public override string Device - { - get - { - using var _ = EagerTensorHandle.Lease(); - return c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.Status.Handle)); - } - } + public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); + public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); + protected override Shape GetShapeInternal() + { + var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; + for (int i = 0; i < dims.Length; i++) + dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle); + return dims; + } + public static int GetRank(IntPtr handle) { var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); diff --git a/src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs b/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs similarity index 88% rename from src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs rename to src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs index 9f91e881..025e6511 100644 --- a/src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs +++ b/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs @@ -20,13 +20,13 @@ using static Tensorflow.Binding; namespace Tensorflow.Eager { - public sealed class SafeTensorHandleHandle : SafeTensorflowHandle + public sealed class SafeEagerTensorHandle : SafeTensorflowHandle { - private SafeTensorHandleHandle() + private SafeEagerTensorHandle() { } - public SafeTensorHandleHandle(IntPtr handle) + public SafeEagerTensorHandle(IntPtr handle) : base(handle) { } diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index f8911bd4..d874ac93 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -94,7 +94,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); + public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status); /// /// @@ -161,7 +161,7 @@ namespace Tensorflow /// /// /// - public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) + public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status) { unsafe { @@ -173,7 +173,7 @@ namespace Tensorflow // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return // values. - retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); + retvals[i] = new SafeEagerTensorHandle(rawReturns[i]); } } } @@ -295,7 +295,7 @@ namespace Tensorflow /// TFE_TensorHandle* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); + public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status); /// /// @@ -303,10 +303,10 @@ namespace Tensorflow /// const tensorflow::Tensor& /// TFE_TensorHandle* [DllImport(TensorFlowLibName)] - public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status); + public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t); + public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t); /// /// Sets the default execution mode (sync/async). Note that this can be @@ -323,7 +323,7 @@ namespace Tensorflow /// TFE_TensorHandle* /// [DllImport(TensorFlowLibName)] - public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h); + public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h); /// /// This function will block till the operation that produces `h` has @@ -334,7 +334,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status); + public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status); /// @@ -344,10 +344,10 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status); + public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status); + public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status); /// /// Returns the device of the operation that produced `h`. If `h` was produced by @@ -360,7 +360,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status); + public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status); /// /// Returns the name of the device in whose memory `h` resides. @@ -369,7 +369,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status); + public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status); /// /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 628d1ce0..3e76d3fa 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -67,9 +67,9 @@ namespace Tensorflow /// /// The DType of elements in this tensor. /// - public TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); + public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); - public ulong dtypesize => _handle == null ? 0 : c_api.TF_DataTypeSize(dtype); + public ulong dtypesize => (ulong)dtype.get_datatype_size(); public ulong size => _handle == null ? 0 : bytesize / dtypesize; public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); @@ -88,11 +88,11 @@ namespace Tensorflow protected new SafeTensorHandle _handle; public SafeTensorHandle Handle => _handle; - protected SafeTensorHandleHandle _eagerTensorHandle; + protected SafeEagerTensorHandle _eagerTensorHandle; /// /// TFE_TensorHandle /// - public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle; + public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; protected bool _isCreatedInGraphMode; @@ -109,19 +109,7 @@ namespace Tensorflow if (rank < 0) return Shape.Null; - var dims = new Shape(new long[rank]); - - if (_handle == null) - { - c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); - } - else - { - for (int i = 0; i < rank; i++) - dims[i] = c_api.TF_Dim(_handle, i); - } - - return dims; + return GetShapeInternal(); } set @@ -142,6 +130,23 @@ namespace Tensorflow } } + protected virtual Shape GetShapeInternal() + { + var dims = new Shape(new long[rank]); + + if (_handle == null) + { + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); + } + else + { + for (int i = 0; i < rank; i++) + dims[i] = c_api.TF_Dim(_handle, i); + } + + return dims; + } + public int[] _shape_tuple() { return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); diff --git a/test/TensorFlowNET.Native.UnitTest/CApiTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs index 781d29ee..2432ec1f 100644 --- a/test/TensorFlowNET.Native.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs @@ -56,10 +56,10 @@ namespace Tensorflow.Native.UnitTest protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) => c_api.TF_SetAttrBool(desc, attrName, value); - protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h) + protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h) => c_api.TFE_TensorHandleDataType(h); - protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status) + protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status) => c_api.TFE_TensorHandleNumDims(h, status); protected TF_Code TF_GetCode(Status s) @@ -80,7 +80,7 @@ namespace Tensorflow.Native.UnitTest protected ulong TF_TensorByteSize(SafeTensorHandle t) => c_api.TF_TensorByteSize(t); - protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) + protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status) => c_api.TFE_OpAddInput(op, h, status); protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) @@ -95,10 +95,10 @@ namespace Tensorflow.Native.UnitTest protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) => c_api.TFE_NewOp(ctx, op_or_function_name, status); - protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) + protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) => c_api.TFE_NewTensorHandle(t, status); - protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) + protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status) => c_api.TFE_Execute(op, retvals, out num_retvals, status); protected SafeContextOptionsHandle TFE_NewContextOptions() @@ -110,7 +110,7 @@ namespace Tensorflow.Native.UnitTest protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) => c_api.TFE_OpGetInputLength(op, input_name, status); - protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) + protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status) => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) @@ -125,13 +125,13 @@ namespace Tensorflow.Native.UnitTest protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); - protected SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status) + protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status) => c_api.TFE_TensorHandleResolve(h, status); - protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status) + protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status) => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); - protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status) + protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status) => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) @@ -146,7 +146,7 @@ namespace Tensorflow.Native.UnitTest protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) => c_api.TF_DeviceListName(list, index, status); - protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) + protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs index 5873b2c9..c8502735 100644 --- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs @@ -32,7 +32,7 @@ namespace Tensorflow.Native.UnitTest.Eager { CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - var retvals = new SafeTensorHandleHandle[2]; + var retvals = new SafeEagerTensorHandle[2]; using (var m = TestMatrixTensorHandle()) using (var matmul = MatMulOp(ctx, m, m)) { diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs index ce5a287f..ff31b195 100644 --- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager using var input1 = TestMatrixTensorHandle(); using var input2 = TestMatrixTensorHandle(); - var retvals = new SafeTensorHandleHandle[2]; + var retvals = new SafeEagerTensorHandle[2]; using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) { CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); @@ -36,7 +36,7 @@ namespace Tensorflow.Native.UnitTest.Eager EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); - var inputs = new SafeTensorHandleHandle[] { input1, input2 }; + var inputs = new SafeEagerTensorHandle[] { input1, input2 }; TFE_OpAddInputList(identityOp, inputs, 2, status); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs index ad878115..ab0d5181 100644 --- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs @@ -41,7 +41,7 @@ namespace Tensorflow.Native.UnitTest.Eager //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); - var retvals = new SafeTensorHandleHandle[0]; + var retvals = new SafeEagerTensorHandle[0]; int num_retvals; TFE_Execute(assertOp, retvals, out num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs index 9fc8f95e..bc430f87 100644 --- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs @@ -39,7 +39,7 @@ namespace Tensorflow.Native.UnitTest.Eager using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); - var retvals = new SafeTensorHandleHandle[1]; + var retvals = new SafeEagerTensorHandle[1]; using (var shape_op = ShapeOp(ctx, hgpu)) { TFE_OpSetDevice(shape_op, gpu_device_name, status); diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs index 310a933a..7c43e111 100644 --- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Native.UnitTest.Eager ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); int num_retvals = 1; - var value_handle = new SafeTensorHandleHandle[1]; + var value_handle = new SafeEagerTensorHandle[1]; using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) { ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs index 864c09f0..c38ba5a5 100644 --- a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest.Eager [TestClass] public partial class CApiEagerTest : CApiTest { - SafeTensorHandleHandle TestMatrixTensorHandle() + SafeEagerTensorHandle TestMatrixTensorHandle() { var dims = new long[] { 2, 2 }; var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager return th; } - SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) + SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeEagerTensorHandle a, SafeEagerTensorHandle b) { using var status = TF_NewStatus(); @@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager return false; } - SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) + SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a) { using var status = TF_NewStatus(); @@ -76,27 +76,27 @@ namespace Tensorflow.Native.UnitTest.Eager return op; } - unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) + unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) { - var var_handle = new SafeTensorHandleHandle[1]; + var var_handle = new SafeEagerTensorHandle[1]; int num_retvals; using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) { - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); TFE_OpSetAttrString(op, "container", "", 0); TFE_OpSetAttrString(op, "shared_name", "", 0); - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); TFE_Execute(op, var_handle, out num_retvals, status); - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); CHECK_EQ(1, num_retvals); } // Assign 'value' to it. using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) { - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpAddInput(op, var_handle[0], status); @@ -105,20 +105,20 @@ namespace Tensorflow.Native.UnitTest.Eager tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); var value_handle = c_api.TFE_NewTensorHandle(t, status); - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); TFE_OpAddInput(op, value_handle, status); - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); c_api.TFE_Execute(op, null, out num_retvals, status); - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); CHECK_EQ(0, num_retvals); } return var_handle[0]; } - SafeTensorHandleHandle TestAxisTensorHandle() + SafeEagerTensorHandle TestAxisTensorHandle() { var dims = new long[] { 1 }; var data = new int[] { 1 }; @@ -131,7 +131,7 @@ namespace Tensorflow.Native.UnitTest.Eager return th; } - SafeTensorHandleHandle TestScalarTensorHandle(bool value) + SafeEagerTensorHandle TestScalarTensorHandle(bool value) { var data = new[] { value }; var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); @@ -143,7 +143,7 @@ namespace Tensorflow.Native.UnitTest.Eager return th; } - SafeTensorHandleHandle TestScalarTensorHandle(float value) + SafeEagerTensorHandle TestScalarTensorHandle(float value) { var data = new[] { value }; var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));