diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 4755454e..560b9536 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -36,6 +36,10 @@ namespace Tensorflow size = (ulong)(nd.size * nd.dtypesize); } + var dataType = ToTFDataType(nd.dtype); + // shape + var dims = nd.shape.Select(x => (long)x).ToArray(); + switch (nd.dtype.Name) { case "Int16": @@ -51,17 +55,8 @@ namespace Tensorflow Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "String": - /*var value = nd.Data()[0]; - var bytes = Encoding.UTF8.GetBytes(value); - dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); - Marshal.Copy(bytes, 0, dotHandle, bytes.Length); - size = (ulong)bytes.Length;*/ - var str = nd.Data()[0]; ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); - //dotHandle = Marshal.AllocHGlobal((int)dst_len); - //size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); - var dataType1 = ToTFDataType(nd.dtype); // shape var dims1 = nd.shape.Select(x => (long)x).ToArray(); @@ -69,19 +64,16 @@ namespace Tensorflow var tfHandle1 = c_api.TF_AllocateTensor(dataType1, dims1, nd.ndim, - dst_len); + dst_len + sizeof(Int64)); dotHandle = c_api.TF_TensorData(tfHandle1); - c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); + Marshal.WriteInt64(dotHandle, 0); + c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle + sizeof(Int64), dst_len, status); return tfHandle1; - break; default: throw new NotImplementedException("Marshal.Copy failed."); } - - var dataType = ToTFDataType(nd.dtype); - // shape - var dims = nd.shape.Select(x => (long)x).ToArray(); + // Free the original buffer and set flag Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => { diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 7ec46172..b1ed3322 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -162,12 +162,12 @@ namespace Tensorflow if (!_is_empty) { - var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] { + model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) }); } - throw new NotImplementedException(""); + throw new NotImplementedException("Saver.save"); return model_checkpoint_path; }