Browse Source

Tensor: correctly pass unmanaged ptr of NDArray to TF

tags/v0.12
Meinrad Recheis 6 years ago
parent
commit
1ac115623c
1 changed files with 22 additions and 34 deletions
  1. +22
    -34
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

+ 22
- 34
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -21,6 +21,7 @@ using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using NumSharp.Backends.Unmanaged;
using static Tensorflow.c_api;

namespace Tensorflow
@@ -484,6 +485,7 @@ namespace Tensorflow

public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
{
// todo: handle nd of type "String" here too
if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte")
{
var buffer = nd.ToArray<byte>();
@@ -502,47 +504,33 @@ namespace Tensorflow
IsMemoryOwner = false;
return;
}
_handle = Allocate(nd, tensorDType: tensorDType);
_handle = CreateTensorFromNDArray(nd, tensorDType);
IsMemoryOwner = true;
}

private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
{
IntPtr dotHandle = IntPtr.Zero;
int buffersize = 0;

if (nd.dtype.Name != "String")
if (nd.dtype.Name == "String")
throw new NotImplementedException("Support for NDArray of type string not implemented yet");
IArraySlice arraySlice;
var shape = nd.Unsafe.Storage.Shape;
if (shape.IsSliced || shape.IsBroadcasted)
{
buffersize = (nd.size * nd.dtypesize);
// dotHandle = Marshal.AllocHGlobal(buffersize);
// the memory is NOT contiguous, so we have to copy the view into a contiguous memory block.
arraySlice = nd.CloneData();
}

var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
// var nd1 = nd.ravel();
/*switch (nd.dtype.Name)
else
{
case "Boolean":
var boolVals = Array.ConvertAll(nd1.ToArray<bool>(), x => Convert.ToByte(x));
Marshal.Copy(boolVals, 0, dotHandle, nd.size);
break;
case "String":
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
//return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.ToArray<string>(0)), TF_DataType.TF_STRING);
default:
System.Buffer.MemoryCopy(nd1.Unsafe.Address, dotHandle.ToPointer(), nd.size, nd.size);
break;
}*/
var tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(nd.Unsafe.Address),
(UIntPtr)buffersize,
_hGlobalDeallocator,
ref _deallocatorArgs);

return tfHandle;
// the memory is contiguous
arraySlice = nd.GetData();
}
this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it
var ptr = new IntPtr(arraySlice.Address);
int num_bytes = (nd.size * nd.dtypesize);
var dtype = given_dtype ?? ToTFDataType(nd.dtype);
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false;
return handle;
}

public unsafe Tensor(byte[][] buffer, long[] shape)


Loading…
Cancel
Save