Browse Source

override dtype and shape

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
fa37f0725a
15 changed files with 83 additions and 76 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  5. +10
    -8
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  6. +3
    -3
      src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs
  7. +12
    -12
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  8. +22
    -17
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  9. +10
    -10
      test/TensorFlowNET.Native.UnitTest/CApiTest.cs
  10. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs
  11. +2
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs
  12. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs
  13. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs
  14. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs
  15. +15
    -15
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs

+ 1
- 1
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>TFE_TensorHandle*</returns> /// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);
public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);


/// <summary> /// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)


+ 2
- 2
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.Eager
{ {
for (int i = 0; i < inputs.Length; ++i) for (int i = 0; i < inputs.Length; ++i)
{ {
SafeTensorHandleHandle tensor_handle = inputs[i] switch
SafeEagerTensorHandle tensor_handle = inputs[i] switch
{ {
EagerTensor et => et.EagerTensorHandle, EagerTensor et => et.EagerTensorHandle,
Tensor nd => nd.EagerTensorHandle, Tensor nd => nd.EagerTensorHandle,
@@ -61,7 +61,7 @@ namespace Tensorflow.Eager
if (status.ok() && attrs != null) if (status.ok() && attrs != null)
SetOpAttrs(op, attrs); SetOpAttrs(op, attrs);


var outputs = new SafeTensorHandleHandle[num_outputs];
var outputs = new SafeEagerTensorHandle[num_outputs];
if (status.ok()) if (status.ok())
{ {
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -141,7 +141,7 @@ namespace Tensorflow.Eager
num_retvals += (int)delta; num_retvals += (int)delta;
} }


var retVals = new SafeTensorHandleHandle[num_retvals];
var retVals = new SafeEagerTensorHandle[num_retvals];
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
status.Check(true); status.Check(true);




+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow.Eager
NewEagerTensorHandle(handle); NewEagerTensorHandle(handle);
} }


public EagerTensor(SafeTensorHandleHandle handle)
public EagerTensor(SafeEagerTensorHandle handle)
{ {
_id = ops.uid(); _id = ops.uid();
_eagerTensorHandle = handle; _eagerTensorHandle = handle;


+ 10
- 8
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -6,17 +6,19 @@ namespace Tensorflow.Eager
{ {
public partial class EagerTensor : Tensor public partial class EagerTensor : Tensor
{ {
public override string Device
{
get
{
using var _ = EagerTensorHandle.Lease();
return c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.Status.Handle));
}
}
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);


public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);


protected override Shape GetShapeInternal()
{
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
for (int i = 0; i < dims.Length; i++)
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle);
return dims;
}

public static int GetRank(IntPtr handle) public static int GetRank(IntPtr handle)
{ {
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);


src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs → src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs View File

@@ -20,13 +20,13 @@ using static Tensorflow.Binding;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
public sealed class SafeTensorHandleHandle : SafeTensorflowHandle
public sealed class SafeEagerTensorHandle : SafeTensorflowHandle
{ {
private SafeTensorHandleHandle()
private SafeEagerTensorHandle()
{ {
} }


public SafeTensorHandleHandle(IntPtr handle)
public SafeEagerTensorHandle(IntPtr handle)
: base(handle) : base(handle)
{ {
} }

+ 12
- 12
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -94,7 +94,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status);
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status);


/// <summary> /// <summary>
/// ///
@@ -161,7 +161,7 @@ namespace Tensorflow
/// <param name="retvals"></param> /// <param name="retvals"></param>
/// <param name="num_retvals"></param> /// <param name="num_retvals"></param>
/// <param name="status"></param> /// <param name="status"></param>
public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
{ {
unsafe unsafe
{ {
@@ -173,7 +173,7 @@ namespace Tensorflow
// A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be
// non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return
// values. // values.
retvals[i] = new SafeTensorHandleHandle(rawReturns[i]);
retvals[i] = new SafeEagerTensorHandle(rawReturns[i]);
} }
} }
} }
@@ -295,7 +295,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param> /// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status);


/// <summary> /// <summary>
/// ///
@@ -303,10 +303,10 @@ namespace Tensorflow
/// <param name="t">const tensorflow::Tensor&amp;</param> /// <param name="t">const tensorflow::Tensor&amp;</param>
/// <returns>TFE_TensorHandle*</returns> /// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);
public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t);
public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t);


/// <summary> /// <summary>
/// Sets the default execution mode (sync/async). Note that this can be /// Sets the default execution mode (sync/async). Note that this can be
@@ -323,7 +323,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param> /// <param name="h">TFE_TensorHandle*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h);
public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h);


/// <summary> /// <summary>
/// This function will block till the operation that produces `h` has /// This function will block till the operation that produces `h` has
@@ -334,7 +334,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status);




/// <summary> /// <summary>
@@ -344,10 +344,10 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status);
public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status);


/// <summary> /// <summary>
/// Returns the device of the operation that produced `h`. If `h` was produced by /// Returns the device of the operation that produced `h`. If `h` was produced by
@@ -360,7 +360,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);


/// <summary> /// <summary>
/// Returns the name of the device in whose memory `h` resides. /// Returns the name of the device in whose memory `h` resides.
@@ -369,7 +369,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);


/// <summary> /// <summary>
/// ///


+ 22
- 17
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -67,9 +67,9 @@ namespace Tensorflow
/// <summary> /// <summary>
/// The DType of elements in this tensor. /// The DType of elements in this tensor.
/// </summary> /// </summary>
public TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong dtypesize => _handle == null ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong dtypesize => (ulong)dtype.get_datatype_size();
public ulong size => _handle == null ? 0 : bytesize / dtypesize; public ulong size => _handle == null ? 0 : bytesize / dtypesize;
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
@@ -88,11 +88,11 @@ namespace Tensorflow
protected new SafeTensorHandle _handle; protected new SafeTensorHandle _handle;
public SafeTensorHandle Handle => _handle; public SafeTensorHandle Handle => _handle;


protected SafeTensorHandleHandle _eagerTensorHandle;
protected SafeEagerTensorHandle _eagerTensorHandle;
/// <summary> /// <summary>
/// TFE_TensorHandle /// TFE_TensorHandle
/// </summary> /// </summary>
public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle;
public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle;


protected bool _isCreatedInGraphMode; protected bool _isCreatedInGraphMode;
@@ -109,19 +109,7 @@ namespace Tensorflow
if (rank < 0) if (rank < 0)
return Shape.Null; return Shape.Null;


var dims = new Shape(new long[rank]);

if (_handle == null)
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}

return dims;
return GetShapeInternal();
} }


set set
@@ -142,6 +130,23 @@ namespace Tensorflow
} }
} }


protected virtual Shape GetShapeInternal()
{
var dims = new Shape(new long[rank]);

if (_handle == null)
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}

return dims;
}

public int[] _shape_tuple() public int[] _shape_tuple()
{ {
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();


+ 10
- 10
test/TensorFlowNET.Native.UnitTest/CApiTest.cs View File

@@ -56,10 +56,10 @@ namespace Tensorflow.Native.UnitTest
protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
=> c_api.TF_SetAttrBool(desc, attrName, value); => c_api.TF_SetAttrBool(desc, attrName, value);


protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h)
protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h)
=> c_api.TFE_TensorHandleDataType(h); => c_api.TFE_TensorHandleDataType(h);


protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status)
protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleNumDims(h, status); => c_api.TFE_TensorHandleNumDims(h, status);


protected TF_Code TF_GetCode(Status s) protected TF_Code TF_GetCode(Status s)
@@ -80,7 +80,7 @@ namespace Tensorflow.Native.UnitTest
protected ulong TF_TensorByteSize(SafeTensorHandle t) protected ulong TF_TensorByteSize(SafeTensorHandle t)
=> c_api.TF_TensorByteSize(t); => c_api.TF_TensorByteSize(t);


protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status)
protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_OpAddInput(op, h, status); => c_api.TFE_OpAddInput(op, h, status);


protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value)
@@ -95,10 +95,10 @@ namespace Tensorflow.Native.UnitTest
protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); => c_api.TFE_NewOp(ctx, op_or_function_name, status);


protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
=> c_api.TFE_NewTensorHandle(t, status); => c_api.TFE_NewTensorHandle(t, status);


protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
=> c_api.TFE_Execute(op, retvals, out num_retvals, status); => c_api.TFE_Execute(op, retvals, out num_retvals, status);


protected SafeContextOptionsHandle TFE_NewContextOptions() protected SafeContextOptionsHandle TFE_NewContextOptions()
@@ -110,7 +110,7 @@ namespace Tensorflow.Native.UnitTest
protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
=> c_api.TFE_OpGetInputLength(op, input_name, status); => c_api.TFE_OpGetInputLength(op, input_name, status);


protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status)
protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status)
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);


protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
@@ -125,13 +125,13 @@ namespace Tensorflow.Native.UnitTest
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status)
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);


protected SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status)
protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleResolve(h, status); => c_api.TFE_TensorHandleResolve(h, status);


protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status));


protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));


protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
@@ -146,7 +146,7 @@ namespace Tensorflow.Native.UnitTest
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.TF_DeviceListName(list, index, status); => c_api.TF_DeviceListName(list, index, status);


protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);


protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status)


+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs View File

@@ -32,7 +32,7 @@ namespace Tensorflow.Native.UnitTest.Eager
{ {
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


var retvals = new SafeTensorHandleHandle[2];
var retvals = new SafeEagerTensorHandle[2];
using (var m = TestMatrixTensorHandle()) using (var m = TestMatrixTensorHandle())
using (var matmul = MatMulOp(ctx, m, m)) using (var matmul = MatMulOp(ctx, m, m))
{ {


+ 2
- 2
test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager
using var input1 = TestMatrixTensorHandle(); using var input1 = TestMatrixTensorHandle();
using var input2 = TestMatrixTensorHandle(); using var input2 = TestMatrixTensorHandle();


var retvals = new SafeTensorHandleHandle[2];
var retvals = new SafeEagerTensorHandle[2];
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) using (var identityOp = TFE_NewOp(ctx, "IdentityN", status))
{ {
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
@@ -36,7 +36,7 @@ namespace Tensorflow.Native.UnitTest.Eager
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));


var inputs = new SafeTensorHandleHandle[] { input1, input2 };
var inputs = new SafeEagerTensorHandle[] { input1, input2 };
TFE_OpAddInputList(identityOp, inputs, 2, status); TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));




+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow.Native.UnitTest.Eager
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);


var retvals = new SafeTensorHandleHandle[0];
var retvals = new SafeEagerTensorHandle[0];
int num_retvals; int num_retvals;
TFE_Execute(assertOp, retvals, out num_retvals, status); TFE_Execute(assertOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow.Native.UnitTest.Eager
using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));


var retvals = new SafeTensorHandleHandle[1];
var retvals = new SafeEagerTensorHandle[1];
using (var shape_op = ShapeOp(ctx, hgpu)) using (var shape_op = ShapeOp(ctx, hgpu))
{ {
TFE_OpSetDevice(shape_op, gpu_device_name, status); TFE_OpSetDevice(shape_op, gpu_device_name, status);


+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Native.UnitTest.Eager
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


int num_retvals = 1; int num_retvals = 1;
var value_handle = new SafeTensorHandleHandle[1];
var value_handle = new SafeEagerTensorHandle[1];
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
{ {
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


+ 15
- 15
test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest.Eager
[TestClass] [TestClass]
public partial class CApiEagerTest : CApiTest public partial class CApiEagerTest : CApiTest
{ {
SafeTensorHandleHandle TestMatrixTensorHandle()
SafeEagerTensorHandle TestMatrixTensorHandle()
{ {
var dims = new long[] { 2, 2 }; var dims = new long[] { 2, 2 };
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th; return th;
} }


SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b)
SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeEagerTensorHandle a, SafeEagerTensorHandle b)
{ {
using var status = TF_NewStatus(); using var status = TF_NewStatus();


@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return false; return false;
} }


SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a)
SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a)
{ {
using var status = TF_NewStatus(); using var status = TF_NewStatus();


@@ -76,27 +76,27 @@ namespace Tensorflow.Native.UnitTest.Eager
return op; return op;
} }


unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
{ {
var var_handle = new SafeTensorHandleHandle[1];
var var_handle = new SafeEagerTensorHandle[1];
int num_retvals; int num_retvals;
using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
{ {
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
TFE_OpSetAttrString(op, "container", "", 0); TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0); TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_Execute(op, var_handle, out num_retvals, status); TFE_Execute(op, var_handle, out num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
CHECK_EQ(1, num_retvals); CHECK_EQ(1, num_retvals);
} }


// Assign 'value' to it. // Assign 'value' to it.
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
{ {
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle[0], status); TFE_OpAddInput(op, var_handle[0], status);


@@ -105,20 +105,20 @@ namespace Tensorflow.Native.UnitTest.Eager
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));


var value_handle = c_api.TFE_NewTensorHandle(t, status); var value_handle = c_api.TFE_NewTensorHandle(t, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);


TFE_OpAddInput(op, value_handle, status); TFE_OpAddInput(op, value_handle, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);


c_api.TFE_Execute(op, null, out num_retvals, status); c_api.TFE_Execute(op, null, out num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
CHECK_EQ(0, num_retvals); CHECK_EQ(0, num_retvals);
} }


return var_handle[0]; return var_handle[0];
} }


SafeTensorHandleHandle TestAxisTensorHandle()
SafeEagerTensorHandle TestAxisTensorHandle()
{ {
var dims = new long[] { 1 }; var dims = new long[] { 1 };
var data = new int[] { 1 }; var data = new int[] { 1 };
@@ -131,7 +131,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th; return th;
} }


SafeTensorHandleHandle TestScalarTensorHandle(bool value)
SafeEagerTensorHandle TestScalarTensorHandle(bool value)
{ {
var data = new[] { value }; var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
@@ -143,7 +143,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th; return th;
} }


SafeTensorHandleHandle TestScalarTensorHandle(float value)
SafeEagerTensorHandle TestScalarTensorHandle(float value)
{ {
var data = new[] { value }; var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));


Loading…
Cancel
Save