Browse Source

Merge pull request #1044 from AsakusaRinne/add_cv_compatibility

Add the constructor of NDArray which reuses memory
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
179e32ae20
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 3 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs
  2. +15
    -0
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.cs
  4. +10
    -1
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  5. +20
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 1
- 1
src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs View File

@@ -70,7 +70,7 @@ public class NpyFormat
if (type == typeof(bool))
return "|b1";
else if (type == typeof(byte))
return "|i1";
return "|u1";
else if (type == typeof(short))
return "<i2";
else if (type == typeof(int))


+ 15
- 0
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -8,6 +8,7 @@ namespace Tensorflow.NumPy
{
public partial class NDArray
{
protected NDArray() { }
public NDArray(bool value) : base(value) => NewEagerTensorHandle();
public NDArray(byte value) : base(value) => NewEagerTensorHandle();
public NDArray(short value) : base(value) => NewEagerTensorHandle();
@@ -57,6 +58,20 @@ namespace Tensorflow.NumPy
_ => throw new NotImplementedException("")
};

/// <summary>
/// Reuse the existing memory instead of copying it.
/// </summary>
/// <param name="data_ptr"></param>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <param name="deallocator"></param>
protected void InitWithExistingMemory(IntPtr data_ptr, Shape shape, TF_DataType dtype, c_api.DeallocatorV2 deallocator)
{
_handle = c_api.TF_NewTensor(TF_DataType.TF_STRING, shape.dims, shape.ndim, data_ptr, (ulong)(shape.size * dtype.get_datatype_size()), deallocator, IntPtr.Zero);
tensor_util.DangerousManuallySetTensorDType(_handle, dtype);
NewEagerTensorHandle();
}

void NewEagerTensorHandle()
{
if (_handle is not null)


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -417,7 +417,7 @@ namespace Tensorflow
{
TF_DataType.TF_DOUBLE => constant(1.0d),
TF_DataType.TF_FLOAT => constant(1.0f),
_ => constant(1)
_ => constant(1, dtype)
};

if (shape.ndim == 0)


+ 10
- 1
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow
/// <param name="deallocator_arg"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, IntPtr deallocator_arg);
public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, DeallocatorV2 deallocator, IntPtr deallocator_arg);

public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype)
{
@@ -147,6 +147,15 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor);

/// <summary>
/// Set a new shape for the Tensor. Note that this API only works after tf2.11.
/// </summary>
/// <param name="tensor"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetShape(SafeTensorHandle tensor, long[] dims, int num_dims);

/// <summary>
/// Return the size in bytes required to encode a string `len` bytes long into a
/// TF_STRING tensor.


+ 20
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -22,6 +22,7 @@ using System.Text;
using Tensorflow.Eager;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
using System.Diagnostics;

namespace Tensorflow
{
@@ -649,5 +650,24 @@ would not be rank 1.", tensor.op.get_attr("axis")));
NewAxisMask = new_axis_mask
};
}

/// <summary>
/// Warning: this method is an extremely dangerous method. It directly changes the dtype inside the tensor
/// and security is not guaranteed at all. Currently this method is only used for some conditions to reuse
/// the existing memory. Any other usage should be prevented. If you are sure you want to use it when
/// developing tensorflow.net, please ask @Oceanic2018 or @AsakusaRinne first.
/// </summary>
/// <param name="handle"></param>
/// <param name="dtype"></param>
internal static unsafe void DangerousManuallySetTensorDType(SafeTensorHandle handle, TF_DataType dtype)
{
long tf_tensor_address = handle.DangerousGetHandle().ToInt64();
long interface_address = *(long*)(tf_tensor_address);
long tensor_shape_address = interface_address + 8;
long tensor_dtype_address = tensor_shape_address + 13;
byte* dtype_pointer = (byte*)tensor_dtype_address;
*dtype_pointer = (byte)dtype;
Debug.Assert(c_api.TF_TensorType(handle) == dtype);
}
}
}

Loading…
Cancel
Save