Browse Source

Malformed TF_STRING tensor; element 0 out of range

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
8ae2feb5fb
8 changed files with 131 additions and 94 deletions
  1. +4
    -4
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +3
    -2
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +111
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  4. +0
    -80
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  6. +3
    -1
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  7. +8
    -4
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  8. +0
    -1
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 4
- 4
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -76,12 +76,12 @@ namespace Tensorflow
obj = temp_obj;

// If obj appears to be a name...
if (obj is String str)
if (obj is string name)
{
if(str.Contains(":") && allow_tensor)
if(name.Contains(":") && allow_tensor)
{
string op_name = str.Split(':')[0];
int out_n = int.Parse(str.Split(':')[1]);
string op_name = name.Split(':')[0];
int out_n = int.Parse(name.Split(':')[1]);

if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n];


+ 3
- 2
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
default:
throw new NotImplementedException("_run subfeed");
}
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value);
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
}
}

@@ -178,7 +178,8 @@ namespace Tensorflow
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
// wired, don't know why we have to start from offset 9.
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9);
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
break;
case TF_DataType.TF_INT16:


+ 111
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -0,0 +1,111 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using static Tensorflow.c_api;

namespace Tensorflow
{
public partial class Tensor
{
/// <summary>
/// if original buffer is free.
/// </summary>
private bool deallocator_called;

public Tensor(IntPtr handle)
{
_handle = handle;
}

public Tensor(NDArray nd)
{
_handle = Allocate(nd);
}

private IntPtr Allocate(NDArray nd)
{
IntPtr dotHandle = IntPtr.Zero;
ulong size = 0;

if (nd.dtype.Name != "String")
{
dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
size = (ulong)(nd.size * nd.dtypesize);
}

switch (nd.dtype.Name)
{
case "Int16":
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
case "String":
/*var value = nd.Data<string>()[0];
var bytes = Encoding.UTF8.GetBytes(value);
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, dotHandle, bytes.Length);
size = (ulong)bytes.Length;*/

var str = nd.Data<string>()[0];
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
//dotHandle = Marshal.AllocHGlobal((int)dst_len);
//size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);

var dataType1 = ToTFDataType(nd.dtype);
// shape
var dims1 = nd.shape.Select(x => (long)x).ToArray();

var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
dims1,
nd.ndim,
dst_len);

dotHandle = c_api.TF_TensorData(tfHandle1);
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);
return tfHandle1;
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");
}

var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
// Free the original buffer and set flag
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) =>
{
Marshal.FreeHGlobal(dotHandle);
closure = true;
};

var tfHandle = c_api.TF_NewTensor(dataType,
dims,
nd.ndim,
dotHandle,
size,
deallocator,
ref deallocator_called);

return tfHandle;
}

public Tensor(Operation op, int value_index, TF_DataType dtype)
{
this.op = op;
this.value_index = value_index;
this._dtype = dtype;
_id = ops.uid();
}
}
}

+ 0
- 80
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -95,86 +95,6 @@ namespace Tensorflow

public int NDims => rank;

/// <summary>
/// if original buffer is free.
/// </summary>
private bool deallocator_called;

public Tensor(IntPtr handle)
{
_handle = handle;
}

public Tensor(NDArray nd)
{
_handle = Allocate(nd);
}

private IntPtr Allocate(NDArray nd)
{
IntPtr dotHandle = IntPtr.Zero;
ulong size = 0;

if (nd.dtype.Name != "String")
{
dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
size = (ulong)(nd.size * nd.dtypesize);
}
switch (nd.dtype.Name)
{
case "Int16":
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
case "String":
var value = nd.Data<string>()[0];
var bytes = Encoding.UTF8.GetBytes(value);
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, dotHandle, bytes.Length);
size = (ulong)bytes.Length;
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");
}

var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
// Free the original buffer and set flag
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) =>
{
Marshal.FreeHGlobal(dotHandle);
closure = true;
};

var tfHandle = c_api.TF_NewTensor(dataType,
dims,
nd.ndim,
dotHandle,
size,
deallocator,
ref deallocator_called);

return tfHandle;
}

public Tensor(Operation op, int value_index, TF_DataType dtype)
{
this.op = op;
this.value_index = value_index;
this._dtype = dtype;
_id = ops.uid();
}

public Operation[] Consumers => consumers();

public string Device => op.Device;


+ 2
- 2
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -120,7 +120,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>On success returns the size in bytes of the encoded string.</returns>
[DllImport(TensorFlowLibName)]
public static extern ulong TF_StringEncode(string src, ulong src_len, string dst, ulong dst_len, IntPtr status);
public static extern ulong TF_StringEncode(string src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status);

/// <summary>
/// Decode a string encoded using TF_StringEncode.
@@ -132,6 +132,6 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern ulong TF_StringDecode(string src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status);
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status);
}
}

+ 3
- 1
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -138,12 +138,14 @@ namespace Tensorflow
public string save(Session sess,
string save_path,
string global_step = "",
string latest_filename = "",
string meta_graph_suffix = "meta",
bool write_meta_graph = true,
bool write_state = true,
bool strip_default_attrs = false)
{
string latest_filename = "checkpoint";
if (string.IsNullOrEmpty(latest_filename))
latest_filename = "checkpoint";
string model_checkpoint_path = "";
string checkpoint_file = "";



+ 8
- 4
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -3,6 +3,7 @@ using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;

@@ -104,11 +105,14 @@ namespace TensorFlowNET.UnitTest
string str = "Hello, TensorFlow.NET!";
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
Assert.AreEqual(dst_len, (ulong)23);
string dst = "";
c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status);
IntPtr dst = Marshal.AllocHGlobal((int)dst_len);
ulong encoded_len = c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status);
Assert.AreEqual((ulong)23, encoded_len);
Assert.AreEqual(status.Code, TF_Code.TF_OK);

//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
Assert.AreEqual(encoded_str, str);
Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
}

/// <summary>


+ 0
- 1
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -45,7 +45,6 @@ namespace TensorFlowNET.UnitTest
});
}

[TestMethod]
public void Save2()
{
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);


Loading…
Cancel
Save