diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 3bab7c07..c6158ab0 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -213,7 +213,7 @@ namespace Tensorflow.Eager if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) { - var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle); + var dtype = tensor.dtype; c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); flattened_attrs.Add(input_arg.TypeAttr); flattened_attrs.Add(dtype); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 1390daf2..b9f741f3 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -7,16 +7,10 @@ namespace Tensorflow.Eager { public partial class EagerTensor { - public EagerTensor(SafeTensorHandle handle) - { - NewEagerTensorHandle(handle); - } - public EagerTensor(SafeEagerTensorHandle handle) { _id = ops.uid(); _eagerTensorHandle = handle; - Resolve(); } #region scalar eager tensor @@ -67,8 +61,10 @@ namespace Tensorflow.Eager tf.Status.Check(true); } - private void Resolve() + public void Resolve() { + if (_handle != null) + return; _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); tf.Status.Check(true); } diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index 30a13312..f85e8df6 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -6,11 +6,47 @@ namespace Tensorflow.Eager { public partial class EagerTensor : Tensor { + public override SafeTensorHandle Handle + { + get + { + Resolve(); + return _handle; + } + } + + public override IntPtr buffer + { + get + { + Resolve(); + return base.buffer; + } + } + public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); + public override ulong bytesize + { + get + { + Resolve(); + return base.bytesize; + } + } + + public override IntPtr TensorDataPointer + { + get + { + Resolve(); + return base.TensorDataPointer; + } + } + protected override Shape GetShapeInternal() { var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; @@ -19,6 +55,12 @@ namespace Tensorflow.Eager return dims; } + protected override void SetShapeInternal(Shape value) + { + if (!shape.is_compatible_with(value)) + throw new ValueError($"Tensor's shape is not compatible."); + } + public static int GetRank(IntPtr handle) { var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); @@ -33,5 +75,11 @@ namespace Tensorflow.Eager dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); return dims; } + + public override T[] ToArray() + { + Resolve(); + return base.ToArray(); + } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index 1139d42c..fc8d6539 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -38,8 +38,8 @@ namespace Tensorflow.NumPy tensor = tf.defaultSession.eval(tensor); _handle = tensor.Handle; } - - NewEagerTensorHandle(); + + NewEagerTensorHandle(); } public static NDArray Scalar(T value) where T : unmanaged @@ -57,7 +57,9 @@ namespace Tensorflow.NumPy void NewEagerTensorHandle() { if (_handle is not null) - _eagerTensorHandle = new EagerTensor(_handle).EagerTensorHandle; + { + _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); + } } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 22eddc55..0e460bd3 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -28,7 +28,7 @@ namespace Tensorflow [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] public partial class Tensor { - public IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); + public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); public Tensor() { diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs index 8e0fb77f..d20c48ab 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs @@ -5,124 +5,88 @@ namespace Tensorflow { public partial class Tensor { - public static explicit operator bool(Tensor tensor) + public unsafe static explicit operator bool(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_BOOL); - return *(bool*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_BOOL); + return *(bool*)tensor.buffer; } - public static explicit operator sbyte(Tensor tensor) + public unsafe static explicit operator sbyte(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_INT8); - return *(sbyte*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT8); + return *(sbyte*)tensor.buffer; } - public static explicit operator byte(Tensor tensor) + public unsafe static explicit operator byte(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_UINT8); - return *(byte*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT8); + return *(byte*)tensor.buffer; } - public static explicit operator ushort(Tensor tensor) + public unsafe static explicit operator ushort(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_UINT16); - return *(ushort*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT16); + return *(ushort*)tensor.buffer; } - public static explicit operator short(Tensor tensor) + public unsafe static explicit operator short(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_INT16); - return *(short*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT16); + return *(short*)tensor.buffer; } - public static explicit operator int(Tensor tensor) + public unsafe static explicit operator int(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_INT32); - return *(int*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT32); + return *(int*)tensor.buffer; } - public static explicit operator uint(Tensor tensor) + public unsafe static explicit operator uint(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_UINT32); - return *(uint*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT32); + return *(uint*)tensor.buffer; } - public static explicit operator long(Tensor tensor) + public unsafe static explicit operator long(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_INT64); - return *(long*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT64); + return *(long*)tensor.buffer; } - public static explicit operator ulong(Tensor tensor) + public unsafe static explicit operator ulong(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_UINT64); - return *(ulong*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT64); + return *(ulong*)tensor.buffer; } - public static explicit operator float(Tensor tensor) + public unsafe static explicit operator float(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_FLOAT); - return *(float*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_FLOAT); + return *(float*)tensor.buffer; } - public static explicit operator double(Tensor tensor) + public unsafe static explicit operator double(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_DOUBLE); - return *(double*)tensor.buffer; - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_DOUBLE); + return *(double*)tensor.buffer; } - public static explicit operator string(Tensor tensor) + public unsafe static explicit operator string(Tensor tensor) { - unsafe - { - EnsureScalar(tensor); - EnsureDType(tensor, TF_DataType.TF_STRING); - return new string((char*)tensor.buffer, 0, (int)tensor.size); - } + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_STRING); + return new string((char*)tensor.buffer, 0, (int)tensor.size); } [MethodImpl(MethodImplOptions.AggressiveInlining)] diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 5f00e6d9..5a977142 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// /// /// - public unsafe T[] ToArray() where T : unmanaged + public virtual unsafe T[] ToArray() where T : unmanaged { //Are the types matching? if (typeof(T).as_tf_dtype() != dtype) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 3e76d3fa..3f4ef8e5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -68,10 +68,10 @@ namespace Tensorflow /// The DType of elements in this tensor. /// public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); - public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); + public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); public ulong dtypesize => (ulong)dtype.get_datatype_size(); public ulong size => _handle == null ? 0 : bytesize / dtypesize; - public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); + public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); public int ndim => rank; @@ -86,7 +86,7 @@ namespace Tensorflow /// public object Tag { get; set; } protected new SafeTensorHandle _handle; - public SafeTensorHandle Handle => _handle; + public virtual SafeTensorHandle Handle => _handle; protected SafeEagerTensorHandle _eagerTensorHandle; /// @@ -114,18 +114,7 @@ namespace Tensorflow set { - if (this is EagerTensor) - { - if(!shape.is_compatible_with(value)) - throw new ValueError($"Tensor's shape is not compatible."); - return; - } - - if (value == null) - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); - else - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); - + SetShapeInternal(value); tf.Status.Check(true); } } @@ -147,6 +136,14 @@ namespace Tensorflow return dims; } + protected virtual void SetShapeInternal(Shape value) + { + if (value == null) + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); + else + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); + } + public int[] _shape_tuple() { return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 96d1e7b4..ae2d4dd8 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -233,9 +233,9 @@ namespace Tensorflow return false; } - if (tensor.GetType() == typeof(EagerTensor)) + if (tensor is EagerTensor eagerTensor) { - if(tensor.dtype == TF_DataType.TF_INT64) + if(tensor.dtype == tf.int64) return new Shape(tensor.ToArray()); else return new Shape(tensor.ToArray());