@@ -67,7 +67,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns>TFE_TensorHandle*</returns> | /// <returns>TFE_TensorHandle*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||||
public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | ||||
@@ -48,7 +48,7 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
for (int i = 0; i < inputs.Length; ++i) | for (int i = 0; i < inputs.Length; ++i) | ||||
{ | { | ||||
SafeTensorHandleHandle tensor_handle = inputs[i] switch | |||||
SafeEagerTensorHandle tensor_handle = inputs[i] switch | |||||
{ | { | ||||
EagerTensor et => et.EagerTensorHandle, | EagerTensor et => et.EagerTensorHandle, | ||||
Tensor nd => nd.EagerTensorHandle, | Tensor nd => nd.EagerTensorHandle, | ||||
@@ -61,7 +61,7 @@ namespace Tensorflow.Eager | |||||
if (status.ok() && attrs != null) | if (status.ok() && attrs != null) | ||||
SetOpAttrs(op, attrs); | SetOpAttrs(op, attrs); | ||||
var outputs = new SafeTensorHandleHandle[num_outputs]; | |||||
var outputs = new SafeEagerTensorHandle[num_outputs]; | |||||
if (status.ok()) | if (status.ok()) | ||||
{ | { | ||||
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); | c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); | ||||
@@ -141,7 +141,7 @@ namespace Tensorflow.Eager | |||||
num_retvals += (int)delta; | num_retvals += (int)delta; | ||||
} | } | ||||
var retVals = new SafeTensorHandleHandle[num_retvals]; | |||||
var retVals = new SafeEagerTensorHandle[num_retvals]; | |||||
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); | c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow.Eager | |||||
NewEagerTensorHandle(handle); | NewEagerTensorHandle(handle); | ||||
} | } | ||||
public EagerTensor(SafeTensorHandleHandle handle) | |||||
public EagerTensor(SafeEagerTensorHandle handle) | |||||
{ | { | ||||
_id = ops.uid(); | _id = ops.uid(); | ||||
_eagerTensorHandle = handle; | _eagerTensorHandle = handle; | ||||
@@ -6,17 +6,19 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public partial class EagerTensor : Tensor | public partial class EagerTensor : Tensor | ||||
{ | { | ||||
public override string Device | |||||
{ | |||||
get | |||||
{ | |||||
using var _ = EagerTensorHandle.Lease(); | |||||
return c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.Status.Handle)); | |||||
} | |||||
} | |||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | |||||
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | |||||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | ||||
protected override Shape GetShapeInternal() | |||||
{ | |||||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | |||||
for (int i = 0; i < dims.Length; i++) | |||||
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle); | |||||
return dims; | |||||
} | |||||
public static int GetRank(IntPtr handle) | public static int GetRank(IntPtr handle) | ||||
{ | { | ||||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
@@ -20,13 +20,13 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
public sealed class SafeTensorHandleHandle : SafeTensorflowHandle | |||||
public sealed class SafeEagerTensorHandle : SafeTensorflowHandle | |||||
{ | { | ||||
private SafeTensorHandleHandle() | |||||
private SafeEagerTensorHandle() | |||||
{ | { | ||||
} | } | ||||
public SafeTensorHandleHandle(IntPtr handle) | |||||
public SafeEagerTensorHandle(IntPtr handle) | |||||
: base(handle) | : base(handle) | ||||
{ | { | ||||
} | } |
@@ -94,7 +94,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||||
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -161,7 +161,7 @@ namespace Tensorflow | |||||
/// <param name="retvals"></param> | /// <param name="retvals"></param> | ||||
/// <param name="num_retvals"></param> | /// <param name="num_retvals"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
{ | { | ||||
unsafe | unsafe | ||||
{ | { | ||||
@@ -173,7 +173,7 @@ namespace Tensorflow | |||||
// A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be | // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be | ||||
// non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return | // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return | ||||
// values. | // values. | ||||
retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); | |||||
retvals[i] = new SafeEagerTensorHandle(rawReturns[i]); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -295,7 +295,7 @@ namespace Tensorflow | |||||
/// <param name="h">TFE_TensorHandle*</param> | /// <param name="h">TFE_TensorHandle*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -303,10 +303,10 @@ namespace Tensorflow | |||||
/// <param name="t">const tensorflow::Tensor&</param> | /// <param name="t">const tensorflow::Tensor&</param> | ||||
/// <returns>TFE_TensorHandle*</returns> | /// <returns>TFE_TensorHandle*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status); | |||||
public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t); | |||||
public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t); | |||||
/// <summary> | /// <summary> | ||||
/// Sets the default execution mode (sync/async). Note that this can be | /// Sets the default execution mode (sync/async). Note that this can be | ||||
@@ -323,7 +323,7 @@ namespace Tensorflow | |||||
/// <param name="h">TFE_TensorHandle*</param> | /// <param name="h">TFE_TensorHandle*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h); | |||||
public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h); | |||||
/// <summary> | /// <summary> | ||||
/// This function will block till the operation that produces `h` has | /// This function will block till the operation that produces `h` has | ||||
@@ -334,7 +334,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
@@ -344,10 +344,10 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status); | |||||
public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the device of the operation that produced `h`. If `h` was produced by | /// Returns the device of the operation that produced `h`. If `h` was produced by | ||||
@@ -360,7 +360,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the name of the device in whose memory `h` resides. | /// Returns the name of the device in whose memory `h` resides. | ||||
@@ -369,7 +369,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -67,9 +67,9 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// The DType of elements in this tensor. | /// The DType of elements in this tensor. | ||||
/// </summary> | /// </summary> | ||||
public TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | |||||
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | |||||
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
public ulong dtypesize => _handle == null ? 0 : c_api.TF_DataTypeSize(dtype); | |||||
public ulong dtypesize => (ulong)dtype.get_datatype_size(); | |||||
public ulong size => _handle == null ? 0 : bytesize / dtypesize; | public ulong size => _handle == null ? 0 : bytesize / dtypesize; | ||||
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
@@ -88,11 +88,11 @@ namespace Tensorflow | |||||
protected new SafeTensorHandle _handle; | protected new SafeTensorHandle _handle; | ||||
public SafeTensorHandle Handle => _handle; | public SafeTensorHandle Handle => _handle; | ||||
protected SafeTensorHandleHandle _eagerTensorHandle; | |||||
protected SafeEagerTensorHandle _eagerTensorHandle; | |||||
/// <summary> | /// <summary> | ||||
/// TFE_TensorHandle | /// TFE_TensorHandle | ||||
/// </summary> | /// </summary> | ||||
public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle; | |||||
public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; | |||||
protected bool _isCreatedInGraphMode; | protected bool _isCreatedInGraphMode; | ||||
@@ -109,19 +109,7 @@ namespace Tensorflow | |||||
if (rank < 0) | if (rank < 0) | ||||
return Shape.Null; | return Shape.Null; | ||||
var dims = new Shape(new long[rank]); | |||||
if (_handle == null) | |||||
{ | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); | |||||
} | |||||
else | |||||
{ | |||||
for (int i = 0; i < rank; i++) | |||||
dims[i] = c_api.TF_Dim(_handle, i); | |||||
} | |||||
return dims; | |||||
return GetShapeInternal(); | |||||
} | } | ||||
set | set | ||||
@@ -142,6 +130,23 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
protected virtual Shape GetShapeInternal() | |||||
{ | |||||
var dims = new Shape(new long[rank]); | |||||
if (_handle == null) | |||||
{ | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); | |||||
} | |||||
else | |||||
{ | |||||
for (int i = 0; i < rank; i++) | |||||
dims[i] = c_api.TF_Dim(_handle, i); | |||||
} | |||||
return dims; | |||||
} | |||||
public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
{ | { | ||||
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | ||||
@@ -56,10 +56,10 @@ namespace Tensorflow.Native.UnitTest | |||||
protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) | protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) | ||||
=> c_api.TF_SetAttrBool(desc, attrName, value); | => c_api.TF_SetAttrBool(desc, attrName, value); | ||||
protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h) | |||||
protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h) | |||||
=> c_api.TFE_TensorHandleDataType(h); | => c_api.TFE_TensorHandleDataType(h); | ||||
protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status) | |||||
=> c_api.TFE_TensorHandleNumDims(h, status); | => c_api.TFE_TensorHandleNumDims(h, status); | ||||
protected TF_Code TF_GetCode(Status s) | protected TF_Code TF_GetCode(Status s) | ||||
@@ -80,7 +80,7 @@ namespace Tensorflow.Native.UnitTest | |||||
protected ulong TF_TensorByteSize(SafeTensorHandle t) | protected ulong TF_TensorByteSize(SafeTensorHandle t) | ||||
=> c_api.TF_TensorByteSize(t); | => c_api.TF_TensorByteSize(t); | ||||
protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status) | |||||
=> c_api.TFE_OpAddInput(op, h, status); | => c_api.TFE_OpAddInput(op, h, status); | ||||
protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) | protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) | ||||
@@ -95,10 +95,10 @@ namespace Tensorflow.Native.UnitTest | |||||
protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | ||||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) | |||||
protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) | |||||
=> c_api.TFE_NewTensorHandle(t, status); | => c_api.TFE_NewTensorHandle(t, status); | ||||
protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
=> c_api.TFE_Execute(op, retvals, out num_retvals, status); | => c_api.TFE_Execute(op, retvals, out num_retvals, status); | ||||
protected SafeContextOptionsHandle TFE_NewContextOptions() | protected SafeContextOptionsHandle TFE_NewContextOptions() | ||||
@@ -110,7 +110,7 @@ namespace Tensorflow.Native.UnitTest | |||||
protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | ||||
=> c_api.TFE_OpGetInputLength(op, input_name, status); | => c_api.TFE_OpGetInputLength(op, input_name, status); | ||||
protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||||
protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||||
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | ||||
protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | ||||
@@ -125,13 +125,13 @@ namespace Tensorflow.Native.UnitTest | |||||
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) | protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) | ||||
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
protected SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status) | |||||
=> c_api.TFE_TensorHandleResolve(h, status); | => c_api.TFE_TensorHandleResolve(h, status); | ||||
protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status) | |||||
=> c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); | ||||
protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status) | |||||
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | ||||
protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | ||||
@@ -146,7 +146,7 @@ namespace Tensorflow.Native.UnitTest | |||||
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | ||||
=> c_api.TF_DeviceListName(list, index, status); | => c_api.TF_DeviceListName(list, index, status); | ||||
protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||||
protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) | protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) | ||||
@@ -32,7 +32,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
{ | { | ||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var retvals = new SafeTensorHandleHandle[2]; | |||||
var retvals = new SafeEagerTensorHandle[2]; | |||||
using (var m = TestMatrixTensorHandle()) | using (var m = TestMatrixTensorHandle()) | ||||
using (var matmul = MatMulOp(ctx, m, m)) | using (var matmul = MatMulOp(ctx, m, m)) | ||||
{ | { | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
using var input1 = TestMatrixTensorHandle(); | using var input1 = TestMatrixTensorHandle(); | ||||
using var input2 = TestMatrixTensorHandle(); | using var input2 = TestMatrixTensorHandle(); | ||||
var retvals = new SafeTensorHandleHandle[2]; | |||||
var retvals = new SafeEagerTensorHandle[2]; | |||||
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | ||||
{ | { | ||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
@@ -36,7 +36,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | ||||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var inputs = new SafeTensorHandleHandle[] { input1, input2 }; | |||||
var inputs = new SafeEagerTensorHandle[] { input1, input2 }; | |||||
TFE_OpAddInputList(identityOp, inputs, 2, status); | TFE_OpAddInputList(identityOp, inputs, 2, status); | ||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
@@ -41,7 +41,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | ||||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | ||||
var retvals = new SafeTensorHandleHandle[0]; | |||||
var retvals = new SafeEagerTensorHandle[0]; | |||||
int num_retvals; | int num_retvals; | ||||
TFE_Execute(assertOp, retvals, out num_retvals, status); | TFE_Execute(assertOp, retvals, out num_retvals, status); | ||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
@@ -39,7 +39,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | ||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ||||
var retvals = new SafeTensorHandleHandle[1]; | |||||
var retvals = new SafeEagerTensorHandle[1]; | |||||
using (var shape_op = ShapeOp(ctx, hgpu)) | using (var shape_op = ShapeOp(ctx, hgpu)) | ||||
{ | { | ||||
TFE_OpSetDevice(shape_op, gpu_device_name, status); | TFE_OpSetDevice(shape_op, gpu_device_name, status); | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
int num_retvals = 1; | int num_retvals = 1; | ||||
var value_handle = new SafeTensorHandleHandle[1]; | |||||
var value_handle = new SafeEagerTensorHandle[1]; | |||||
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | ||||
{ | { | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
[TestClass] | [TestClass] | ||||
public partial class CApiEagerTest : CApiTest | public partial class CApiEagerTest : CApiTest | ||||
{ | { | ||||
SafeTensorHandleHandle TestMatrixTensorHandle() | |||||
SafeEagerTensorHandle TestMatrixTensorHandle() | |||||
{ | { | ||||
var dims = new long[] { 2, 2 }; | var dims = new long[] { 2, 2 }; | ||||
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return th; | return th; | ||||
} | } | ||||
SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) | |||||
SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeEagerTensorHandle a, SafeEagerTensorHandle b) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return false; | return false; | ||||
} | } | ||||
SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) | |||||
SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -76,27 +76,27 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return op; | return op; | ||||
} | } | ||||
unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||||
unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||||
{ | { | ||||
var var_handle = new SafeTensorHandleHandle[1]; | |||||
var var_handle = new SafeEagerTensorHandle[1]; | |||||
int num_retvals; | int num_retvals; | ||||
using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | ||||
{ | { | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | ||||
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | ||||
TFE_OpSetAttrString(op, "container", "", 0); | TFE_OpSetAttrString(op, "container", "", 0); | ||||
TFE_OpSetAttrString(op, "shared_name", "", 0); | TFE_OpSetAttrString(op, "shared_name", "", 0); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
TFE_Execute(op, var_handle, out num_retvals, status); | TFE_Execute(op, var_handle, out num_retvals, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
CHECK_EQ(1, num_retvals); | CHECK_EQ(1, num_retvals); | ||||
} | } | ||||
// Assign 'value' to it. | // Assign 'value' to it. | ||||
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | ||||
{ | { | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | ||||
TFE_OpAddInput(op, var_handle[0], status); | TFE_OpAddInput(op, var_handle[0], status); | ||||
@@ -105,20 +105,20 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | ||||
var value_handle = c_api.TFE_NewTensorHandle(t, status); | var value_handle = c_api.TFE_NewTensorHandle(t, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
TFE_OpAddInput(op, value_handle, status); | TFE_OpAddInput(op, value_handle, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
c_api.TFE_Execute(op, null, out num_retvals, status); | c_api.TFE_Execute(op, null, out num_retvals, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); | |||||
CHECK_EQ(0, num_retvals); | CHECK_EQ(0, num_retvals); | ||||
} | } | ||||
return var_handle[0]; | return var_handle[0]; | ||||
} | } | ||||
SafeTensorHandleHandle TestAxisTensorHandle() | |||||
SafeEagerTensorHandle TestAxisTensorHandle() | |||||
{ | { | ||||
var dims = new long[] { 1 }; | var dims = new long[] { 1 }; | ||||
var data = new int[] { 1 }; | var data = new int[] { 1 }; | ||||
@@ -131,7 +131,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return th; | return th; | ||||
} | } | ||||
SafeTensorHandleHandle TestScalarTensorHandle(bool value) | |||||
SafeEagerTensorHandle TestScalarTensorHandle(bool value) | |||||
{ | { | ||||
var data = new[] { value }; | var data = new[] { value }; | ||||
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | ||||
@@ -143,7 +143,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return th; | return th; | ||||
} | } | ||||
SafeTensorHandleHandle TestScalarTensorHandle(float value) | |||||
SafeEagerTensorHandle TestScalarTensorHandle(float value) | |||||
{ | { | ||||
var data = new[] { value }; | var data = new[] { value }; | ||||
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | ||||