|
|
@@ -486,7 +486,7 @@ namespace Tensorflow |
|
|
|
{ |
|
|
|
if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") |
|
|
|
{ |
|
|
|
var buffer = nd.ToArray<byte>(); |
|
|
|
var buffer = nd.Data<byte>(); |
|
|
|
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); |
|
|
|
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); |
|
|
|
|
|
|
@@ -524,30 +524,29 @@ namespace Tensorflow |
|
|
|
switch (nd.dtype.Name) |
|
|
|
{ |
|
|
|
case "Boolean": |
|
|
|
var boolVals = Array.ConvertAll(nd1.ToArray<bool>(), x => Convert.ToByte(x)); |
|
|
|
var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x)); |
|
|
|
Marshal.Copy(boolVals, 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "Int16": |
|
|
|
Marshal.Copy(nd1.ToArray<short>(), 0, dotHandle, nd.size); |
|
|
|
Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "Int32": |
|
|
|
Marshal.Copy(nd1.ToArray<int>(), 0, dotHandle, nd.size); |
|
|
|
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "Int64": |
|
|
|
Marshal.Copy(nd1.ToArray<long>(), 0, dotHandle, nd.size); |
|
|
|
Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "Single": |
|
|
|
Marshal.Copy(nd1.ToArray<float>(), 0, dotHandle, nd.size); |
|
|
|
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "Double": |
|
|
|
Marshal.Copy(nd1.ToArray<double>(), 0, dotHandle, nd.size); |
|
|
|
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "Byte": |
|
|
|
Marshal.Copy(nd1.ToArray<byte>(), 0, dotHandle, nd.size); |
|
|
|
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size); |
|
|
|
break; |
|
|
|
case "String": |
|
|
|
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); |
|
|
|
//return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING); |
|
|
|
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING); |
|
|
|
default: |
|
|
|
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); |
|
|
|
} |
|
|
|