Browse Source

override dtype and shape

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
fa37f0725a
15 changed files with 83 additions and 76 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  5. +10
    -8
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  6. +3
    -3
      src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs
  7. +12
    -12
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  8. +22
    -17
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  9. +10
    -10
      test/TensorFlowNET.Native.UnitTest/CApiTest.cs
  10. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs
  11. +2
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs
  12. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs
  13. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs
  14. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs
  15. +15
    -15
      test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs

+ 1
- 1
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)]
public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);
public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);

/// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)


+ 2
- 2
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.Eager
{
for (int i = 0; i < inputs.Length; ++i)
{
SafeTensorHandleHandle tensor_handle = inputs[i] switch
SafeEagerTensorHandle tensor_handle = inputs[i] switch
{
EagerTensor et => et.EagerTensorHandle,
Tensor nd => nd.EagerTensorHandle,
@@ -61,7 +61,7 @@ namespace Tensorflow.Eager
if (status.ok() && attrs != null)
SetOpAttrs(op, attrs);

var outputs = new SafeTensorHandleHandle[num_outputs];
var outputs = new SafeEagerTensorHandle[num_outputs];
if (status.ok())
{
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -141,7 +141,7 @@ namespace Tensorflow.Eager
num_retvals += (int)delta;
}

var retVals = new SafeTensorHandleHandle[num_retvals];
var retVals = new SafeEagerTensorHandle[num_retvals];
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
status.Check(true);



+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow.Eager
NewEagerTensorHandle(handle);
}

public EagerTensor(SafeTensorHandleHandle handle)
public EagerTensor(SafeEagerTensorHandle handle)
{
_id = ops.uid();
_eagerTensorHandle = handle;


+ 10
- 8
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -6,17 +6,19 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor : Tensor
{
public override string Device
{
get
{
using var _ = EagerTensorHandle.Lease();
return c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.Status.Handle));
}
}
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);

public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);

protected override Shape GetShapeInternal()
{
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
for (int i = 0; i < dims.Length; i++)
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle);
return dims;
}

public static int GetRank(IntPtr handle)
{
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);


src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs → src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs View File

@@ -20,13 +20,13 @@ using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
public sealed class SafeTensorHandleHandle : SafeTensorflowHandle
public sealed class SafeEagerTensorHandle : SafeTensorflowHandle
{
private SafeTensorHandleHandle()
private SafeEagerTensorHandle()
{
}

public SafeTensorHandleHandle(IntPtr handle)
public SafeEagerTensorHandle(IntPtr handle)
: base(handle)
{
}

+ 12
- 12
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -94,7 +94,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status);
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status);

/// <summary>
///
@@ -161,7 +161,7 @@ namespace Tensorflow
/// <param name="retvals"></param>
/// <param name="num_retvals"></param>
/// <param name="status"></param>
public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
{
unsafe
{
@@ -173,7 +173,7 @@ namespace Tensorflow
// A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be
// non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return
// values.
retvals[i] = new SafeTensorHandleHandle(rawReturns[i]);
retvals[i] = new SafeEagerTensorHandle(rawReturns[i]);
}
}
}
@@ -295,7 +295,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status);

/// <summary>
///
@@ -303,10 +303,10 @@ namespace Tensorflow
/// <param name="t">const tensorflow::Tensor&amp;</param>
/// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)]
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);
public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t);
public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t);

/// <summary>
/// Sets the default execution mode (sync/async). Note that this can be
@@ -323,7 +323,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h);
public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h);

/// <summary>
/// This function will block till the operation that produces `h` has
@@ -334,7 +334,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status);


/// <summary>
@@ -344,10 +344,10 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status);
public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status);

/// <summary>
/// Returns the device of the operation that produced `h`. If `h` was produced by
@@ -360,7 +360,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);

/// <summary>
/// Returns the name of the device in whose memory `h` resides.
@@ -369,7 +369,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);

/// <summary>
///


+ 22
- 17
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -67,9 +67,9 @@ namespace Tensorflow
/// <summary>
/// The DType of elements in this tensor.
/// </summary>
public TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong dtypesize => _handle == null ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong dtypesize => (ulong)dtype.get_datatype_size();
public ulong size => _handle == null ? 0 : bytesize / dtypesize;
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
@@ -88,11 +88,11 @@ namespace Tensorflow
protected new SafeTensorHandle _handle;
public SafeTensorHandle Handle => _handle;

protected SafeTensorHandleHandle _eagerTensorHandle;
protected SafeEagerTensorHandle _eagerTensorHandle;
/// <summary>
/// TFE_TensorHandle
/// </summary>
public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle;
public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle;

protected bool _isCreatedInGraphMode;
@@ -109,19 +109,7 @@ namespace Tensorflow
if (rank < 0)
return Shape.Null;

var dims = new Shape(new long[rank]);

if (_handle == null)
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}

return dims;
return GetShapeInternal();
}

set
@@ -142,6 +130,23 @@ namespace Tensorflow
}
}

protected virtual Shape GetShapeInternal()
{
var dims = new Shape(new long[rank]);

if (_handle == null)
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}

return dims;
}

public int[] _shape_tuple()
{
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();


+ 10
- 10
test/TensorFlowNET.Native.UnitTest/CApiTest.cs View File

@@ -56,10 +56,10 @@ namespace Tensorflow.Native.UnitTest
protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
=> c_api.TF_SetAttrBool(desc, attrName, value);

protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h)
protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h)
=> c_api.TFE_TensorHandleDataType(h);

protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status)
protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleNumDims(h, status);

protected TF_Code TF_GetCode(Status s)
@@ -80,7 +80,7 @@ namespace Tensorflow.Native.UnitTest
protected ulong TF_TensorByteSize(SafeTensorHandle t)
=> c_api.TF_TensorByteSize(t);

protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status)
protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_OpAddInput(op, h, status);

protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value)
@@ -95,10 +95,10 @@ namespace Tensorflow.Native.UnitTest
protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
=> c_api.TFE_NewOp(ctx, op_or_function_name, status);

protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
=> c_api.TFE_NewTensorHandle(t, status);

protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
=> c_api.TFE_Execute(op, retvals, out num_retvals, status);

protected SafeContextOptionsHandle TFE_NewContextOptions()
@@ -110,7 +110,7 @@ namespace Tensorflow.Native.UnitTest
protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
=> c_api.TFE_OpGetInputLength(op, input_name, status);

protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status)
protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status)
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);

protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
@@ -125,13 +125,13 @@ namespace Tensorflow.Native.UnitTest
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status)
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);

protected SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status)
protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleResolve(h, status);

protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status));

protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));

protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
@@ -146,7 +146,7 @@ namespace Tensorflow.Native.UnitTest
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.TF_DeviceListName(list, index, status);

protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);

protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status)


+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs View File

@@ -32,7 +32,7 @@ namespace Tensorflow.Native.UnitTest.Eager
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var retvals = new SafeTensorHandleHandle[2];
var retvals = new SafeEagerTensorHandle[2];
using (var m = TestMatrixTensorHandle())
using (var matmul = MatMulOp(ctx, m, m))
{


+ 2
- 2
test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager
using var input1 = TestMatrixTensorHandle();
using var input2 = TestMatrixTensorHandle();

var retvals = new SafeTensorHandleHandle[2];
var retvals = new SafeEagerTensorHandle[2];
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status))
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
@@ -36,7 +36,7 @@ namespace Tensorflow.Native.UnitTest.Eager
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));

var inputs = new SafeTensorHandleHandle[] { input1, input2 };
var inputs = new SafeEagerTensorHandle[] { input1, input2 };
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));



+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow.Native.UnitTest.Eager
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);

var retvals = new SafeTensorHandleHandle[0];
var retvals = new SafeEagerTensorHandle[0];
int num_retvals;
TFE_Execute(assertOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow.Native.UnitTest.Eager
using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));

var retvals = new SafeTensorHandleHandle[1];
var retvals = new SafeEagerTensorHandle[1];
using (var shape_op = ShapeOp(ctx, hgpu))
{
TFE_OpSetDevice(shape_op, gpu_device_name, status);


+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Native.UnitTest.Eager
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_retvals = 1;
var value_handle = new SafeTensorHandleHandle[1];
var value_handle = new SafeEagerTensorHandle[1];
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
{
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


+ 15
- 15
test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest.Eager
[TestClass]
public partial class CApiEagerTest : CApiTest
{
SafeTensorHandleHandle TestMatrixTensorHandle()
SafeEagerTensorHandle TestMatrixTensorHandle()
{
var dims = new long[] { 2, 2 };
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th;
}

SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b)
SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeEagerTensorHandle a, SafeEagerTensorHandle b)
{
using var status = TF_NewStatus();

@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return false;
}

SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a)
SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a)
{
using var status = TF_NewStatus();

@@ -76,27 +76,27 @@ namespace Tensorflow.Native.UnitTest.Eager
return op;
}

unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
{
var var_handle = new SafeTensorHandleHandle[1];
var var_handle = new SafeEagerTensorHandle[1];
int num_retvals;
using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
{
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_Execute(op, var_handle, out num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
CHECK_EQ(1, num_retvals);
}

// Assign 'value' to it.
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
{
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle[0], status);

@@ -105,20 +105,20 @@ namespace Tensorflow.Native.UnitTest.Eager
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));

var value_handle = c_api.TFE_NewTensorHandle(t, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);

TFE_OpAddInput(op, value_handle, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);

c_api.TFE_Execute(op, null, out num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero);
CHECK_EQ(0, num_retvals);
}

return var_handle[0];
}

SafeTensorHandleHandle TestAxisTensorHandle()
SafeEagerTensorHandle TestAxisTensorHandle()
{
var dims = new long[] { 1 };
var data = new int[] { 1 };
@@ -131,7 +131,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th;
}

SafeTensorHandleHandle TestScalarTensorHandle(bool value)
SafeEagerTensorHandle TestScalarTensorHandle(bool value)
{
var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
@@ -143,7 +143,7 @@ namespace Tensorflow.Native.UnitTest.Eager
return th;
}

SafeTensorHandleHandle TestScalarTensorHandle(float value)
SafeEagerTensorHandle TestScalarTensorHandle(float value)
{
var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));


Loading…
Cancel
Save