Add the constructor of NDArray which reuses memorytags/v0.100.5-BERT-load
@@ -70,7 +70,7 @@ public class NpyFormat | |||||
if (type == typeof(bool)) | if (type == typeof(bool)) | ||||
return "|b1"; | return "|b1"; | ||||
else if (type == typeof(byte)) | else if (type == typeof(byte)) | ||||
return "|i1"; | |||||
return "|u1"; | |||||
else if (type == typeof(short)) | else if (type == typeof(short)) | ||||
return "<i2"; | return "<i2"; | ||||
else if (type == typeof(int)) | else if (type == typeof(int)) | ||||
@@ -8,6 +8,7 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
public partial class NDArray | public partial class NDArray | ||||
{ | { | ||||
protected NDArray() { } | |||||
public NDArray(bool value) : base(value) => NewEagerTensorHandle(); | public NDArray(bool value) : base(value) => NewEagerTensorHandle(); | ||||
public NDArray(byte value) : base(value) => NewEagerTensorHandle(); | public NDArray(byte value) : base(value) => NewEagerTensorHandle(); | ||||
public NDArray(short value) : base(value) => NewEagerTensorHandle(); | public NDArray(short value) : base(value) => NewEagerTensorHandle(); | ||||
@@ -57,6 +58,20 @@ namespace Tensorflow.NumPy | |||||
_ => throw new NotImplementedException("") | _ => throw new NotImplementedException("") | ||||
}; | }; | ||||
/// <summary> | |||||
/// Reuse the existing memory instead of copying it. | |||||
/// </summary> | |||||
/// <param name="data_ptr"></param> | |||||
/// <param name="shape"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="deallocator"></param> | |||||
protected void InitWithExistingMemory(IntPtr data_ptr, Shape shape, TF_DataType dtype, c_api.DeallocatorV2 deallocator) | |||||
{ | |||||
_handle = c_api.TF_NewTensor(TF_DataType.TF_STRING, shape.dims, shape.ndim, data_ptr, (ulong)(shape.size * dtype.get_datatype_size()), deallocator, IntPtr.Zero); | |||||
tensor_util.DangerousManuallySetTensorDType(_handle, dtype); | |||||
NewEagerTensorHandle(); | |||||
} | |||||
void NewEagerTensorHandle() | void NewEagerTensorHandle() | ||||
{ | { | ||||
if (_handle is not null) | if (_handle is not null) | ||||
@@ -417,7 +417,7 @@ namespace Tensorflow | |||||
{ | { | ||||
TF_DataType.TF_DOUBLE => constant(1.0d), | TF_DataType.TF_DOUBLE => constant(1.0d), | ||||
TF_DataType.TF_FLOAT => constant(1.0f), | TF_DataType.TF_FLOAT => constant(1.0f), | ||||
_ => constant(1) | |||||
_ => constant(1, dtype) | |||||
}; | }; | ||||
if (shape.ndim == 0) | if (shape.ndim == 0) | ||||
@@ -71,7 +71,7 @@ namespace Tensorflow | |||||
/// <param name="deallocator_arg"></param> | /// <param name="deallocator_arg"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, IntPtr deallocator_arg); | |||||
public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, DeallocatorV2 deallocator, IntPtr deallocator_arg); | |||||
public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype) | public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype) | ||||
{ | { | ||||
@@ -147,6 +147,15 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor); | public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor); | ||||
/// <summary> | |||||
/// Set a new shape for the Tensor. Note that this API only works after tf2.11. | |||||
/// </summary> | |||||
/// <param name="tensor"></param> | |||||
/// <param name="dims"></param> | |||||
/// <param name="num_dims"></param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_SetShape(SafeTensorHandle tensor, long[] dims, int num_dims); | |||||
/// <summary> | /// <summary> | ||||
/// Return the size in bytes required to encode a string `len` bytes long into a | /// Return the size in bytes required to encode a string `len` bytes long into a | ||||
/// TF_STRING tensor. | /// TF_STRING tensor. | ||||
@@ -22,6 +22,7 @@ using System.Text; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Diagnostics; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -649,5 +650,24 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
NewAxisMask = new_axis_mask | NewAxisMask = new_axis_mask | ||||
}; | }; | ||||
} | } | ||||
/// <summary> | |||||
/// Warning: this method is an extremely dangerous method. It directly changes the dtype inside the tensor | |||||
/// and security is not guaranteed at all. Currently this method is only used for some conditions to reuse | |||||
/// the existing memory. Any other usage should be prevented. If you are sure you want to use it when | |||||
/// developing tensorflow.net, please ask @Oceanic2018 or @AsakusaRinne first. | |||||
/// </summary> | |||||
/// <param name="handle"></param> | |||||
/// <param name="dtype"></param> | |||||
internal static unsafe void DangerousManuallySetTensorDType(SafeTensorHandle handle, TF_DataType dtype) | |||||
{ | |||||
long tf_tensor_address = handle.DangerousGetHandle().ToInt64(); | |||||
long interface_address = *(long*)(tf_tensor_address); | |||||
long tensor_shape_address = interface_address + 8; | |||||
long tensor_dtype_address = tensor_shape_address + 13; | |||||
byte* dtype_pointer = (byte*)tensor_dtype_address; | |||||
*dtype_pointer = (byte)dtype; | |||||
Debug.Assert(c_api.TF_TensorType(handle) == dtype); | |||||
} | |||||
} | } | ||||
} | } |