@@ -213,7 +213,7 @@ namespace Tensorflow.Eager | |||||
if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
{ | { | ||||
var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle); | |||||
var dtype = tensor.dtype; | |||||
c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | ||||
flattened_attrs.Add(input_arg.TypeAttr); | flattened_attrs.Add(input_arg.TypeAttr); | ||||
flattened_attrs.Add(dtype); | flattened_attrs.Add(dtype); | ||||
@@ -7,16 +7,10 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public partial class EagerTensor | public partial class EagerTensor | ||||
{ | { | ||||
public EagerTensor(SafeTensorHandle handle) | |||||
{ | |||||
NewEagerTensorHandle(handle); | |||||
} | |||||
public EagerTensor(SafeEagerTensorHandle handle) | public EagerTensor(SafeEagerTensorHandle handle) | ||||
{ | { | ||||
_id = ops.uid(); | _id = ops.uid(); | ||||
_eagerTensorHandle = handle; | _eagerTensorHandle = handle; | ||||
Resolve(); | |||||
} | } | ||||
#region scalar eager tensor | #region scalar eager tensor | ||||
@@ -67,8 +61,10 @@ namespace Tensorflow.Eager | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
private void Resolve() | |||||
public void Resolve() | |||||
{ | { | ||||
if (_handle != null) | |||||
return; | |||||
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | ||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
@@ -6,11 +6,47 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public partial class EagerTensor : Tensor | public partial class EagerTensor : Tensor | ||||
{ | { | ||||
public override SafeTensorHandle Handle | |||||
{ | |||||
get | |||||
{ | |||||
Resolve(); | |||||
return _handle; | |||||
} | |||||
} | |||||
public override IntPtr buffer | |||||
{ | |||||
get | |||||
{ | |||||
Resolve(); | |||||
return base.buffer; | |||||
} | |||||
} | |||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | ||||
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | ||||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | ||||
public override ulong bytesize | |||||
{ | |||||
get | |||||
{ | |||||
Resolve(); | |||||
return base.bytesize; | |||||
} | |||||
} | |||||
public override IntPtr TensorDataPointer | |||||
{ | |||||
get | |||||
{ | |||||
Resolve(); | |||||
return base.TensorDataPointer; | |||||
} | |||||
} | |||||
protected override Shape GetShapeInternal() | protected override Shape GetShapeInternal() | ||||
{ | { | ||||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | ||||
@@ -19,6 +55,12 @@ namespace Tensorflow.Eager | |||||
return dims; | return dims; | ||||
} | } | ||||
protected override void SetShapeInternal(Shape value) | |||||
{ | |||||
if (!shape.is_compatible_with(value)) | |||||
throw new ValueError($"Tensor's shape is not compatible."); | |||||
} | |||||
public static int GetRank(IntPtr handle) | public static int GetRank(IntPtr handle) | ||||
{ | { | ||||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
@@ -33,5 +75,11 @@ namespace Tensorflow.Eager | |||||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | ||||
return dims; | return dims; | ||||
} | } | ||||
public override T[] ToArray<T>() | |||||
{ | |||||
Resolve(); | |||||
return base.ToArray<T>(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -38,8 +38,8 @@ namespace Tensorflow.NumPy | |||||
tensor = tf.defaultSession.eval(tensor); | tensor = tf.defaultSession.eval(tensor); | ||||
_handle = tensor.Handle; | _handle = tensor.Handle; | ||||
} | } | ||||
NewEagerTensorHandle(); | |||||
NewEagerTensorHandle(); | |||||
} | } | ||||
public static NDArray Scalar<T>(T value) where T : unmanaged | public static NDArray Scalar<T>(T value) where T : unmanaged | ||||
@@ -57,7 +57,9 @@ namespace Tensorflow.NumPy | |||||
void NewEagerTensorHandle() | void NewEagerTensorHandle() | ||||
{ | { | ||||
if (_handle is not null) | if (_handle is not null) | ||||
_eagerTensorHandle = new EagerTensor(_handle).EagerTensorHandle; | |||||
{ | |||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -28,7 +28,7 @@ namespace Tensorflow | |||||
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | ||||
public partial class Tensor | public partial class Tensor | ||||
{ | { | ||||
public IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||||
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||||
public Tensor() | public Tensor() | ||||
{ | { | ||||
@@ -5,124 +5,88 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class Tensor | public partial class Tensor | ||||
{ | { | ||||
public static explicit operator bool(Tensor tensor) | |||||
public unsafe static explicit operator bool(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
return *(bool*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
return *(bool*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator sbyte(Tensor tensor) | |||||
public unsafe static explicit operator sbyte(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
return *(sbyte*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
return *(sbyte*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator byte(Tensor tensor) | |||||
public unsafe static explicit operator byte(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
return *(byte*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
return *(byte*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator ushort(Tensor tensor) | |||||
public unsafe static explicit operator ushort(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
return *(ushort*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
return *(ushort*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator short(Tensor tensor) | |||||
public unsafe static explicit operator short(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
return *(short*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
return *(short*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator int(Tensor tensor) | |||||
public unsafe static explicit operator int(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
return *(int*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
return *(int*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator uint(Tensor tensor) | |||||
public unsafe static explicit operator uint(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
return *(uint*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
return *(uint*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator long(Tensor tensor) | |||||
public unsafe static explicit operator long(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
return *(long*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
return *(long*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator ulong(Tensor tensor) | |||||
public unsafe static explicit operator ulong(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
return *(ulong*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
return *(ulong*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator float(Tensor tensor) | |||||
public unsafe static explicit operator float(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
return *(float*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
return *(float*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator double(Tensor tensor) | |||||
public unsafe static explicit operator double(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
return *(double*)tensor.buffer; | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
return *(double*)tensor.buffer; | |||||
} | } | ||||
public static explicit operator string(Tensor tensor) | |||||
public unsafe static explicit operator string(Tensor tensor) | |||||
{ | { | ||||
unsafe | |||||
{ | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_STRING); | |||||
return new string((char*)tensor.buffer, 0, (int)tensor.size); | |||||
} | |||||
EnsureScalar(tensor); | |||||
EnsureDType(tensor, TF_DataType.TF_STRING); | |||||
return new string((char*)tensor.buffer, 0, (int)tensor.size); | |||||
} | } | ||||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <typeparam name="T"></typeparam> | /// <typeparam name="T"></typeparam> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public unsafe T[] ToArray<T>() where T : unmanaged | |||||
public virtual unsafe T[] ToArray<T>() where T : unmanaged | |||||
{ | { | ||||
//Are the types matching? | //Are the types matching? | ||||
if (typeof(T).as_tf_dtype() != dtype) | if (typeof(T).as_tf_dtype() != dtype) | ||||
@@ -68,10 +68,10 @@ namespace Tensorflow | |||||
/// The DType of elements in this tensor. | /// The DType of elements in this tensor. | ||||
/// </summary> | /// </summary> | ||||
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | ||||
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | |||||
public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | |||||
public ulong dtypesize => (ulong)dtype.get_datatype_size(); | public ulong dtypesize => (ulong)dtype.get_datatype_size(); | ||||
public ulong size => _handle == null ? 0 : bytesize / dtypesize; | public ulong size => _handle == null ? 0 : bytesize / dtypesize; | ||||
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||||
public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||||
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
public int ndim => rank; | public int ndim => rank; | ||||
@@ -86,7 +86,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public object Tag { get; set; } | public object Tag { get; set; } | ||||
protected new SafeTensorHandle _handle; | protected new SafeTensorHandle _handle; | ||||
public SafeTensorHandle Handle => _handle; | |||||
public virtual SafeTensorHandle Handle => _handle; | |||||
protected SafeEagerTensorHandle _eagerTensorHandle; | protected SafeEagerTensorHandle _eagerTensorHandle; | ||||
/// <summary> | /// <summary> | ||||
@@ -114,18 +114,7 @@ namespace Tensorflow | |||||
set | set | ||||
{ | { | ||||
if (this is EagerTensor) | |||||
{ | |||||
if(!shape.is_compatible_with(value)) | |||||
throw new ValueError($"Tensor's shape is not compatible."); | |||||
return; | |||||
} | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||||
SetShapeInternal(value); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
} | } | ||||
@@ -147,6 +136,14 @@ namespace Tensorflow | |||||
return dims; | return dims; | ||||
} | } | ||||
protected virtual void SetShapeInternal(Shape value) | |||||
{ | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||||
} | |||||
public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
{ | { | ||||
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | ||||
@@ -233,9 +233,9 @@ namespace Tensorflow | |||||
return false; | return false; | ||||
} | } | ||||
if (tensor.GetType() == typeof(EagerTensor)) | |||||
if (tensor is EagerTensor eagerTensor) | |||||
{ | { | ||||
if(tensor.dtype == TF_DataType.TF_INT64) | |||||
if(tensor.dtype == tf.int64) | |||||
return new Shape(tensor.ToArray<long>()); | return new Shape(tensor.ToArray<long>()); | ||||
else | else | ||||
return new Shape(tensor.ToArray<int>()); | return new Shape(tensor.ToArray<int>()); | ||||