Browse Source

ByteString #479

tags/v0.20
Oceania2018 5 years ago
parent
commit
06c08c93de
5 changed files with 12 additions and 5 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  4. +5
    -2
      src/TensorFlowNET.Core/Training/Saving/Saver.cs
  5. +1
    -1
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 3
- 0
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -86,6 +86,9 @@ namespace Tensorflow
case NPTypeCode.Char: case NPTypeCode.Char:
full_values.Add(float.NaN); full_values.Add(float.NaN);
break; break;
case NPTypeCode.Byte:
full_values.Add(float.NaN);
break;
default: default:
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}"); throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}");
} }


+ 1
- 1
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -138,7 +138,7 @@ namespace Tensorflow
dtype = TF_DataType.TF_INT8; dtype = TF_DataType.TF_INT8;
break; break;
case "Byte": case "Byte":
dtype = TF_DataType.TF_UINT8;
dtype = dtype ?? TF_DataType.TF_UINT8;
break; break;
case "Int16": case "Int16":
dtype = TF_DataType.TF_INT16; dtype = TF_DataType.TF_INT16;


+ 2
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -18,6 +18,7 @@ using NumSharp;
using System; using System;
using System.Linq; using System.Linq;
using NumSharp.Utilities; using NumSharp.Utilities;
using System.Text;


namespace Tensorflow namespace Tensorflow
{ {
@@ -256,7 +257,7 @@ namespace Tensorflow
nd = np.array(doubleVals); nd = np.array(doubleVals);
break; break;
case string strVal: case string strVal:
nd = NDArray.FromString(strVal);
nd = new NDArray(Encoding.ASCII.GetBytes(strVal));
break; break;
case string[] strVals: case string[] strVals:
nd = strVals; nd = strVals;


+ 5
- 2
src/TensorFlowNET.Core/Training/Saving/Saver.cs View File

@@ -19,6 +19,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -188,10 +189,12 @@ namespace Tensorflow


if (write_state) if (write_state)
{ {
var path = UTF8Encoding.UTF8.GetString((byte[])model_checkpoint_path[0]);
_RecordLastCheckpoint(path);
_RecordLastCheckpoint(model_checkpoint_path[0].ToString()); _RecordLastCheckpoint(model_checkpoint_path[0].ToString());
checkpoint_management.update_checkpoint_state_internal( checkpoint_management.update_checkpoint_state_internal(
save_dir: save_path_parent, save_dir: save_path_parent,
model_checkpoint_path: model_checkpoint_path[0].ToString(),
model_checkpoint_path: path,
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(),
latest_filename: latest_filename, latest_filename: latest_filename,
save_relative_paths: _save_relative_paths); save_relative_paths: _save_relative_paths);
@@ -205,7 +208,7 @@ namespace Tensorflow
export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info);
} }


return _is_empty ? string.Empty : model_checkpoint_path.ToString();
return _is_empty ? string.Empty : UTF8Encoding.UTF8.GetString((byte[])model_checkpoint_path[0]);
} }


public (Saver, object) import_meta_graph(string meta_graph_or_file, public (Saver, object) import_meta_graph(string meta_graph_or_file,


+ 1
- 1
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -109,7 +109,7 @@ namespace TensorFlowNET.UnitTest
var c = tf.strings.substr(a, 4, 8); var c = tf.strings.substr(a, 4, 8);
using (var sess = tf.Session()) using (var sess = tf.Session())
{ {
var result = (string) c.eval(sess);
var result = UTF8Encoding.UTF8.GetString((byte[])c.eval(sess));
Console.WriteLine(result); Console.WriteLine(result);
result.Should().Be("heythere"); result.Should().Be("heythere");
} }


Loading…
Cancel
Save