@@ -86,6 +86,9 @@ namespace Tensorflow | |||||
case NPTypeCode.Char: | case NPTypeCode.Char: | ||||
full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
break; | break; | ||||
case NPTypeCode.Byte: | |||||
full_values.Add(float.NaN); | |||||
break; | |||||
default: | default: | ||||
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}"); | throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}"); | ||||
} | } | ||||
@@ -138,7 +138,7 @@ namespace Tensorflow | |||||
dtype = TF_DataType.TF_INT8; | dtype = TF_DataType.TF_INT8; | ||||
break; | break; | ||||
case "Byte": | case "Byte": | ||||
dtype = TF_DataType.TF_UINT8; | |||||
dtype = dtype ?? TF_DataType.TF_UINT8; | |||||
break; | break; | ||||
case "Int16": | case "Int16": | ||||
dtype = TF_DataType.TF_INT16; | dtype = TF_DataType.TF_INT16; | ||||
@@ -18,6 +18,7 @@ using NumSharp; | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -256,7 +257,7 @@ namespace Tensorflow | |||||
nd = np.array(doubleVals); | nd = np.array(doubleVals); | ||||
break; | break; | ||||
case string strVal: | case string strVal: | ||||
nd = NDArray.FromString(strVal); | |||||
nd = new NDArray(Encoding.ASCII.GetBytes(strVal)); | |||||
break; | break; | ||||
case string[] strVals: | case string[] strVals: | ||||
nd = strVals; | nd = strVals; | ||||
@@ -19,6 +19,7 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -188,10 +189,12 @@ namespace Tensorflow | |||||
if (write_state) | if (write_state) | ||||
{ | { | ||||
var path = UTF8Encoding.UTF8.GetString((byte[])model_checkpoint_path[0]); | |||||
_RecordLastCheckpoint(path); | |||||
_RecordLastCheckpoint(model_checkpoint_path[0].ToString()); | _RecordLastCheckpoint(model_checkpoint_path[0].ToString()); | ||||
checkpoint_management.update_checkpoint_state_internal( | checkpoint_management.update_checkpoint_state_internal( | ||||
save_dir: save_path_parent, | save_dir: save_path_parent, | ||||
model_checkpoint_path: model_checkpoint_path[0].ToString(), | |||||
model_checkpoint_path: path, | |||||
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), | all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), | ||||
latest_filename: latest_filename, | latest_filename: latest_filename, | ||||
save_relative_paths: _save_relative_paths); | save_relative_paths: _save_relative_paths); | ||||
@@ -205,7 +208,7 @@ namespace Tensorflow | |||||
export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); | export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); | ||||
} | } | ||||
return _is_empty ? string.Empty : model_checkpoint_path.ToString(); | |||||
return _is_empty ? string.Empty : UTF8Encoding.UTF8.GetString((byte[])model_checkpoint_path[0]); | |||||
} | } | ||||
public (Saver, object) import_meta_graph(string meta_graph_or_file, | public (Saver, object) import_meta_graph(string meta_graph_or_file, | ||||
@@ -109,7 +109,7 @@ namespace TensorFlowNET.UnitTest | |||||
var c = tf.strings.substr(a, 4, 8); | var c = tf.strings.substr(a, 4, 8); | ||||
using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
{ | { | ||||
var result = (string) c.eval(sess); | |||||
var result = UTF8Encoding.UTF8.GetString((byte[])c.eval(sess)); | |||||
Console.WriteLine(result); | Console.WriteLine(result); | ||||
result.Should().Be("heythere"); | result.Should().Be("heythere"); | ||||
} | } | ||||