diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
index d9743ead..af7e94c8 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
@@ -8,6 +8,7 @@ namespace Tensorflow.NumPy
{
public partial class NDArray
{
+ protected NDArray() { }
public NDArray(bool value) : base(value) => NewEagerTensorHandle();
public NDArray(byte value) : base(value) => NewEagerTensorHandle();
public NDArray(short value) : base(value) => NewEagerTensorHandle();
@@ -57,6 +58,20 @@ namespace Tensorflow.NumPy
_ => throw new NotImplementedException("")
};
+ ///
+ /// Reuse the existing memory instead of copying it.
+ ///
+ ///
+ ///
+ ///
+ ///
+ protected void InitWithExistingMemory(IntPtr data_ptr, Shape shape, TF_DataType dtype, c_api.DeallocatorV2 deallocator)
+ {
+ _handle = c_api.TF_NewTensor(TF_DataType.TF_STRING, shape.dims, shape.ndim, data_ptr, (ulong)(shape.size * dtype.get_datatype_size()), deallocator, IntPtr.Zero);
+ tensor_util.DangerousManuallySetTensorDType(_handle, dtype);
+ NewEagerTensorHandle();
+ }
+
void NewEagerTensorHandle()
{
if (_handle is not null)
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 0e888a0a..2767e821 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -417,7 +417,7 @@ namespace Tensorflow
{
TF_DataType.TF_DOUBLE => constant(1.0d),
TF_DataType.TF_FLOAT => constant(1.0f),
- _ => constant(1)
+ _ => constant(1, dtype)
};
if (shape.ndim == 0)
diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
index 2e7edc66..3779ddcf 100644
--- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
@@ -71,7 +71,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, IntPtr deallocator_arg);
+ public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, DeallocatorV2 deallocator, IntPtr deallocator_arg);
public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype)
{
@@ -147,6 +147,15 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor);
+ ///
+ /// Set a new shape for the Tensor. Note that this API only works after tf2.11.
+ ///
+ ///
+ ///
+ ///
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_SetShape(SafeTensorHandle tensor, long[] dims, int num_dims);
+
///
/// Return the size in bytes required to encode a string `len` bytes long into a
/// TF_STRING tensor.
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 25bb8882..e65c4850 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -22,6 +22,7 @@ using System.Text;
using Tensorflow.Eager;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
+using System.Diagnostics;
namespace Tensorflow
{
@@ -649,5 +650,24 @@ would not be rank 1.", tensor.op.get_attr("axis")));
NewAxisMask = new_axis_mask
};
}
+
+ ///
+ /// Warning: this method is an extremely dangerous method. It directly changes the dtype inside the tensor
+ /// and security is not guaranteed at all. Currently this method is only used for some conditions to reuse
+ /// the existing memory. Any other usage should be prevented. If you are sure you want to use it when
+ /// developing tensorflow.net, please ask @Oceanic2018 or @AsakusaRinne first.
+ ///
+ ///
+ ///
+ internal static unsafe void DangerousManuallySetTensorDType(SafeTensorHandle handle, TF_DataType dtype)
+ {
+ long tf_tensor_address = handle.DangerousGetHandle().ToInt64();
+ long interface_address = *(long*)(tf_tensor_address);
+ long tensor_shape_address = interface_address + 8;
+ long tensor_dtype_address = tensor_shape_address + 13;
+ byte* dtype_pointer = (byte*)tensor_dtype_address;
+ *dtype_pointer = (byte)dtype;
+ Debug.Assert(c_api.TF_TensorType(handle) == dtype);
+ }
}
}