|
|
@@ -69,6 +69,7 @@ namespace Tensorflow |
|
|
|
private IntPtr Allocate(NDArray nd) |
|
|
|
{ |
|
|
|
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); |
|
|
|
ulong size = (ulong)(nd.size * nd.dtypesize); |
|
|
|
|
|
|
|
switch (nd.dtype.Name) |
|
|
|
{ |
|
|
@@ -81,16 +82,21 @@ namespace Tensorflow |
|
|
|
case "Double": |
|
|
|
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "String": |
|
|
|
dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]); |
|
|
|
size = (ulong)nd.Data<string>()[0].Length; |
|
|
|
break; |
|
|
|
default: |
|
|
|
throw new NotImplementedException("Marshal.Copy failed."); |
|
|
|
} |
|
|
|
|
|
|
|
var dataType = ToTFDataType(nd.dtype); |
|
|
|
|
|
|
|
var tfHandle = c_api.TF_NewTensor(dataType, |
|
|
|
nd.shape.Select(x => (long)x).ToArray(), // shape |
|
|
|
nd.ndim, |
|
|
|
dotHandle, |
|
|
|
(ulong)(nd.size * nd.dtypesize), |
|
|
|
size, |
|
|
|
(IntPtr values, IntPtr len, ref bool closure) => |
|
|
|
{ |
|
|
|
// Free the original buffer and set flag |
|
|
@@ -154,6 +160,8 @@ namespace Tensorflow |
|
|
|
return TF_DataType.TF_FLOAT; |
|
|
|
case "Double": |
|
|
|
return TF_DataType.TF_DOUBLE; |
|
|
|
case "String": |
|
|
|
return TF_DataType.TF_STRING; |
|
|
|
} |
|
|
|
|
|
|
|
return TF_DataType.DtInvalid; |
|
|
|