Browse Source

Fix allocate tensor for string #171

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
2579afc84a
2 changed files with 10 additions and 18 deletions
  1. +8
    -16
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Train/Saving/Saver.cs

+ 8
- 16
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -36,6 +36,10 @@ namespace Tensorflow
size = (ulong)(nd.size * nd.dtypesize);
}

var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();

switch (nd.dtype.Name)
{
case "Int16":
@@ -51,17 +55,8 @@ namespace Tensorflow
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
case "String":
/*var value = nd.Data<string>()[0];
var bytes = Encoding.UTF8.GetBytes(value);
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, dotHandle, bytes.Length);
size = (ulong)bytes.Length;*/

var str = nd.Data<string>()[0];
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
//dotHandle = Marshal.AllocHGlobal((int)dst_len);
//size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);

var dataType1 = ToTFDataType(nd.dtype);
// shape
var dims1 = nd.shape.Select(x => (long)x).ToArray();
@@ -69,19 +64,16 @@ namespace Tensorflow
var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
dims1,
nd.ndim,
dst_len);
dst_len + sizeof(Int64));

dotHandle = c_api.TF_TensorData(tfHandle1);
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);
Marshal.WriteInt64(dotHandle, 0);
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle + sizeof(Int64), dst_len, status);
return tfHandle1;
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");
}

var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
// Free the original buffer and set flag
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) =>
{


+ 2
- 2
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -162,12 +162,12 @@ namespace Tensorflow

if (!_is_empty)
{
var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
});
}

throw new NotImplementedException("");
throw new NotImplementedException("Saver.save");

return model_checkpoint_path;
}


Loading…
Cancel
Save