@@ -76,12 +76,12 @@ namespace Tensorflow | |||
obj = temp_obj; | |||
// If obj appears to be a name... | |||
if (obj is String str) | |||
if (obj is string name) | |||
{ | |||
if(str.Contains(":") && allow_tensor) | |||
if(name.Contains(":") && allow_tensor) | |||
{ | |||
string op_name = str.Split(':')[0]; | |||
int out_n = int.Parse(str.Split(':')[1]); | |||
string op_name = name.Split(':')[0]; | |||
int out_n = int.Parse(name.Split(':')[1]); | |||
if (_nodes_by_name.ContainsKey(op_name)) | |||
return _nodes_by_name[op_name].outputs[out_n]; | |||
@@ -67,7 +67,7 @@ namespace Tensorflow | |||
default: | |||
throw new NotImplementedException("_run subfeed"); | |||
} | |||
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||
} | |||
} | |||
@@ -178,7 +178,8 @@ namespace Tensorflow | |||
case TF_DataType.TF_STRING: | |||
var bytes = tensor.Data(); | |||
// wired, don't know why we have to start from offset 9. | |||
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); | |||
// length in the begin | |||
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
nd = np.array(str).reshape(); | |||
break; | |||
case TF_DataType.TF_INT16: | |||
@@ -0,0 +1,111 @@ | |||
using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using static Tensorflow.c_api; | |||
namespace Tensorflow | |||
{ | |||
public partial class Tensor | |||
{ | |||
/// <summary> | |||
/// if original buffer is free. | |||
/// </summary> | |||
private bool deallocator_called; | |||
public Tensor(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
} | |||
public Tensor(NDArray nd) | |||
{ | |||
_handle = Allocate(nd); | |||
} | |||
private IntPtr Allocate(NDArray nd) | |||
{ | |||
IntPtr dotHandle = IntPtr.Zero; | |||
ulong size = 0; | |||
if (nd.dtype.Name != "String") | |||
{ | |||
dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | |||
size = (ulong)(nd.size * nd.dtypesize); | |||
} | |||
switch (nd.dtype.Name) | |||
{ | |||
case "Int16": | |||
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Int32": | |||
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Single": | |||
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Double": | |||
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||
break; | |||
case "String": | |||
/*var value = nd.Data<string>()[0]; | |||
var bytes = Encoding.UTF8.GetBytes(value); | |||
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); | |||
Marshal.Copy(bytes, 0, dotHandle, bytes.Length); | |||
size = (ulong)bytes.Length;*/ | |||
var str = nd.Data<string>()[0]; | |||
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | |||
//dotHandle = Marshal.AllocHGlobal((int)dst_len); | |||
//size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); | |||
var dataType1 = ToTFDataType(nd.dtype); | |||
// shape | |||
var dims1 = nd.shape.Select(x => (long)x).ToArray(); | |||
var tfHandle1 = c_api.TF_AllocateTensor(dataType1, | |||
dims1, | |||
nd.ndim, | |||
dst_len); | |||
dotHandle = c_api.TF_TensorData(tfHandle1); | |||
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); | |||
return tfHandle1; | |||
break; | |||
default: | |||
throw new NotImplementedException("Marshal.Copy failed."); | |||
} | |||
var dataType = ToTFDataType(nd.dtype); | |||
// shape | |||
var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
// Free the original buffer and set flag | |||
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => | |||
{ | |||
Marshal.FreeHGlobal(dotHandle); | |||
closure = true; | |||
}; | |||
var tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
nd.ndim, | |||
dotHandle, | |||
size, | |||
deallocator, | |||
ref deallocator_called); | |||
return tfHandle; | |||
} | |||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
{ | |||
this.op = op; | |||
this.value_index = value_index; | |||
this._dtype = dtype; | |||
_id = ops.uid(); | |||
} | |||
} | |||
} |
@@ -95,86 +95,6 @@ namespace Tensorflow | |||
public int NDims => rank; | |||
/// <summary> | |||
/// if original buffer is free. | |||
/// </summary> | |||
private bool deallocator_called; | |||
public Tensor(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
} | |||
public Tensor(NDArray nd) | |||
{ | |||
_handle = Allocate(nd); | |||
} | |||
private IntPtr Allocate(NDArray nd) | |||
{ | |||
IntPtr dotHandle = IntPtr.Zero; | |||
ulong size = 0; | |||
if (nd.dtype.Name != "String") | |||
{ | |||
dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | |||
size = (ulong)(nd.size * nd.dtypesize); | |||
} | |||
switch (nd.dtype.Name) | |||
{ | |||
case "Int16": | |||
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Int32": | |||
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Single": | |||
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Double": | |||
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||
break; | |||
case "String": | |||
var value = nd.Data<string>()[0]; | |||
var bytes = Encoding.UTF8.GetBytes(value); | |||
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); | |||
Marshal.Copy(bytes, 0, dotHandle, bytes.Length); | |||
size = (ulong)bytes.Length; | |||
break; | |||
default: | |||
throw new NotImplementedException("Marshal.Copy failed."); | |||
} | |||
var dataType = ToTFDataType(nd.dtype); | |||
// shape | |||
var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
// Free the original buffer and set flag | |||
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => | |||
{ | |||
Marshal.FreeHGlobal(dotHandle); | |||
closure = true; | |||
}; | |||
var tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
nd.ndim, | |||
dotHandle, | |||
size, | |||
deallocator, | |||
ref deallocator_called); | |||
return tfHandle; | |||
} | |||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
{ | |||
this.op = op; | |||
this.value_index = value_index; | |||
this._dtype = dtype; | |||
_id = ops.uid(); | |||
} | |||
public Operation[] Consumers => consumers(); | |||
public string Device => op.Device; | |||
@@ -120,7 +120,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>On success returns the size in bytes of the encoded string.</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern ulong TF_StringEncode(string src, ulong src_len, string dst, ulong dst_len, IntPtr status); | |||
public static extern ulong TF_StringEncode(string src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status); | |||
/// <summary> | |||
/// Decode a string encoded using TF_StringEncode. | |||
@@ -132,6 +132,6 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern ulong TF_StringDecode(string src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); | |||
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); | |||
} | |||
} |
@@ -138,12 +138,14 @@ namespace Tensorflow | |||
public string save(Session sess, | |||
string save_path, | |||
string global_step = "", | |||
string latest_filename = "", | |||
string meta_graph_suffix = "meta", | |||
bool write_meta_graph = true, | |||
bool write_state = true, | |||
bool strip_default_attrs = false) | |||
{ | |||
string latest_filename = "checkpoint"; | |||
if (string.IsNullOrEmpty(latest_filename)) | |||
latest_filename = "checkpoint"; | |||
string model_checkpoint_path = ""; | |||
string checkpoint_file = ""; | |||
@@ -3,6 +3,7 @@ using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using Tensorflow; | |||
@@ -104,11 +105,14 @@ namespace TensorFlowNET.UnitTest | |||
string str = "Hello, TensorFlow.NET!"; | |||
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | |||
Assert.AreEqual(dst_len, (ulong)23); | |||
string dst = ""; | |||
c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status); | |||
IntPtr dst = Marshal.AllocHGlobal((int)dst_len); | |||
ulong encoded_len = c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status); | |||
Assert.AreEqual((ulong)23, encoded_len); | |||
Assert.AreEqual(status.Code, TF_Code.TF_OK); | |||
//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | |||
Assert.AreEqual(encoded_str, str); | |||
Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | |||
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||
} | |||
/// <summary> | |||
@@ -45,7 +45,6 @@ namespace TensorFlowNET.UnitTest | |||
}); | |||
} | |||
[TestMethod] | |||
public void Save2() | |||
{ | |||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||