@@ -213,7 +213,7 @@ namespace Tensorflow.Eager | |||
if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | |||
{ | |||
var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle); | |||
var dtype = tensor.dtype; | |||
c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | |||
flattened_attrs.Add(input_arg.TypeAttr); | |||
flattened_attrs.Add(dtype); | |||
@@ -7,16 +7,10 @@ namespace Tensorflow.Eager | |||
{ | |||
public partial class EagerTensor | |||
{ | |||
public EagerTensor(SafeTensorHandle handle) | |||
{ | |||
NewEagerTensorHandle(handle); | |||
} | |||
public EagerTensor(SafeEagerTensorHandle handle) | |||
{ | |||
_id = ops.uid(); | |||
_eagerTensorHandle = handle; | |||
Resolve(); | |||
} | |||
#region scalar eager tensor | |||
@@ -67,8 +61,10 @@ namespace Tensorflow.Eager | |||
tf.Status.Check(true); | |||
} | |||
private void Resolve() | |||
public void Resolve() | |||
{ | |||
if (_handle != null) | |||
return; | |||
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
} | |||
@@ -6,11 +6,47 @@ namespace Tensorflow.Eager | |||
{ | |||
public partial class EagerTensor : Tensor | |||
{ | |||
public override SafeTensorHandle Handle | |||
{ | |||
get | |||
{ | |||
Resolve(); | |||
return _handle; | |||
} | |||
} | |||
public override IntPtr buffer | |||
{ | |||
get | |||
{ | |||
Resolve(); | |||
return base.buffer; | |||
} | |||
} | |||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | |||
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | |||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | |||
public override ulong bytesize | |||
{ | |||
get | |||
{ | |||
Resolve(); | |||
return base.bytesize; | |||
} | |||
} | |||
public override IntPtr TensorDataPointer | |||
{ | |||
get | |||
{ | |||
Resolve(); | |||
return base.TensorDataPointer; | |||
} | |||
} | |||
protected override Shape GetShapeInternal() | |||
{ | |||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | |||
@@ -19,6 +55,12 @@ namespace Tensorflow.Eager | |||
return dims; | |||
} | |||
protected override void SetShapeInternal(Shape value) | |||
{ | |||
if (!shape.is_compatible_with(value)) | |||
throw new ValueError($"Tensor's shape is not compatible."); | |||
} | |||
public static int GetRank(IntPtr handle) | |||
{ | |||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
@@ -33,5 +75,11 @@ namespace Tensorflow.Eager | |||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | |||
return dims; | |||
} | |||
public override T[] ToArray<T>() | |||
{ | |||
Resolve(); | |||
return base.ToArray<T>(); | |||
} | |||
} | |||
} |
@@ -38,8 +38,8 @@ namespace Tensorflow.NumPy | |||
tensor = tf.defaultSession.eval(tensor); | |||
_handle = tensor.Handle; | |||
} | |||
NewEagerTensorHandle(); | |||
NewEagerTensorHandle(); | |||
} | |||
public static NDArray Scalar<T>(T value) where T : unmanaged | |||
@@ -57,7 +57,9 @@ namespace Tensorflow.NumPy | |||
void NewEagerTensorHandle() | |||
{ | |||
if (_handle is not null) | |||
_eagerTensorHandle = new EagerTensor(_handle).EagerTensorHandle; | |||
{ | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||
} | |||
} | |||
} | |||
} |
@@ -28,7 +28,7 @@ namespace Tensorflow | |||
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | |||
public partial class Tensor | |||
{ | |||
public IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||
public Tensor() | |||
{ | |||
@@ -5,124 +5,88 @@ namespace Tensorflow | |||
{ | |||
public partial class Tensor | |||
{ | |||
public static explicit operator bool(Tensor tensor) | |||
public unsafe static explicit operator bool(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_BOOL); | |||
return *(bool*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_BOOL); | |||
return *(bool*)tensor.buffer; | |||
} | |||
public static explicit operator sbyte(Tensor tensor) | |||
public unsafe static explicit operator sbyte(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT8); | |||
return *(sbyte*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT8); | |||
return *(sbyte*)tensor.buffer; | |||
} | |||
public static explicit operator byte(Tensor tensor) | |||
public unsafe static explicit operator byte(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT8); | |||
return *(byte*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT8); | |||
return *(byte*)tensor.buffer; | |||
} | |||
public static explicit operator ushort(Tensor tensor) | |||
public unsafe static explicit operator ushort(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT16); | |||
return *(ushort*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT16); | |||
return *(ushort*)tensor.buffer; | |||
} | |||
public static explicit operator short(Tensor tensor) | |||
public unsafe static explicit operator short(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT16); | |||
return *(short*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT16); | |||
return *(short*)tensor.buffer; | |||
} | |||
public static explicit operator int(Tensor tensor) | |||
public unsafe static explicit operator int(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT32); | |||
return *(int*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT32); | |||
return *(int*)tensor.buffer; | |||
} | |||
public static explicit operator uint(Tensor tensor) | |||
public unsafe static explicit operator uint(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT32); | |||
return *(uint*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT32); | |||
return *(uint*)tensor.buffer; | |||
} | |||
public static explicit operator long(Tensor tensor) | |||
public unsafe static explicit operator long(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT64); | |||
return *(long*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT64); | |||
return *(long*)tensor.buffer; | |||
} | |||
public static explicit operator ulong(Tensor tensor) | |||
public unsafe static explicit operator ulong(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT64); | |||
return *(ulong*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT64); | |||
return *(ulong*)tensor.buffer; | |||
} | |||
public static explicit operator float(Tensor tensor) | |||
public unsafe static explicit operator float(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||
return *(float*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||
return *(float*)tensor.buffer; | |||
} | |||
public static explicit operator double(Tensor tensor) | |||
public unsafe static explicit operator double(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||
return *(double*)tensor.buffer; | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||
return *(double*)tensor.buffer; | |||
} | |||
public static explicit operator string(Tensor tensor) | |||
public unsafe static explicit operator string(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_STRING); | |||
return new string((char*)tensor.buffer, 0, (int)tensor.size); | |||
} | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_STRING); | |||
return new string((char*)tensor.buffer, 0, (int)tensor.size); | |||
} | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
@@ -12,7 +12,7 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <returns></returns> | |||
public unsafe T[] ToArray<T>() where T : unmanaged | |||
public virtual unsafe T[] ToArray<T>() where T : unmanaged | |||
{ | |||
//Are the types matching? | |||
if (typeof(T).as_tf_dtype() != dtype) | |||
@@ -68,10 +68,10 @@ namespace Tensorflow | |||
/// The DType of elements in this tensor. | |||
/// </summary> | |||
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | |||
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | |||
public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | |||
public ulong dtypesize => (ulong)dtype.get_datatype_size(); | |||
public ulong size => _handle == null ? 0 : bytesize / dtypesize; | |||
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | |||
public int ndim => rank; | |||
@@ -86,7 +86,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public object Tag { get; set; } | |||
protected new SafeTensorHandle _handle; | |||
public SafeTensorHandle Handle => _handle; | |||
public virtual SafeTensorHandle Handle => _handle; | |||
protected SafeEagerTensorHandle _eagerTensorHandle; | |||
/// <summary> | |||
@@ -114,18 +114,7 @@ namespace Tensorflow | |||
set | |||
{ | |||
if (this is EagerTensor) | |||
{ | |||
if(!shape.is_compatible_with(value)) | |||
throw new ValueError($"Tensor's shape is not compatible."); | |||
return; | |||
} | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||
else | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||
SetShapeInternal(value); | |||
tf.Status.Check(true); | |||
} | |||
} | |||
@@ -147,6 +136,14 @@ namespace Tensorflow | |||
return dims; | |||
} | |||
protected virtual void SetShapeInternal(Shape value) | |||
{ | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||
else | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||
} | |||
public int[] _shape_tuple() | |||
{ | |||
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | |||
@@ -233,9 +233,9 @@ namespace Tensorflow | |||
return false; | |||
} | |||
if (tensor.GetType() == typeof(EagerTensor)) | |||
if (tensor is EagerTensor eagerTensor) | |||
{ | |||
if(tensor.dtype == TF_DataType.TF_INT64) | |||
if(tensor.dtype == tf.int64) | |||
return new Shape(tensor.ToArray<long>()); | |||
else | |||
return new Shape(tensor.ToArray<int>()); | |||