Browse Source

Revise customized json converters.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
01e88bb8bb
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
6 changed files with 50 additions and 10 deletions
  1. +10
    -1
      src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
  2. +20
    -2
      src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  4. +8
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  5. +10
    -5
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  6. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 10
- 1
src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs View File

@@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var axis = serializer.Deserialize(reader, typeof(int[]));
int[]? axis;
if(reader.ValueType == typeof(long))
{
axis = new int[1];
axis[0] = (int)serializer.Deserialize(reader, typeof(int));
}
else
{
axis = serializer.Deserialize(reader, typeof(int[])) as int[];
}
if (axis is null)
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");


+ 20
- 2
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -51,8 +51,26 @@ namespace Tensorflow.Keras.Common

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
if(dims is null)
long?[] dims;
try
{
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
catch (JsonSerializationException ex)
{
if (reader.Value.Equals("class_name"))
{
reader.Read();
reader.Read();
reader.Read();
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
else
{
throw ex;
}
}
if (dims is null)
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}


+ 1
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -11,6 +11,7 @@ using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using System.Runtime.CompilerServices;
using Tensorflow.Variables;
using Tensorflow.Functions;

namespace Tensorflow
{


+ 8
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -75,7 +75,14 @@ namespace Tensorflow.Keras.Engine
this.inputs = inputs;
this.outputs = outputs;
built = true;
_buildInputShape = inputs.shape;
if(inputs.Length > 0)
{
_buildInputShape = inputs.shape;
}
else
{
_buildInputShape = new Saving.TensorShapeConfig();
}

if (outputs.Any(x => x.KerasHistory == null))
base_layer_utils.create_keras_history(outputs);


+ 10
- 5
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -72,6 +72,10 @@ namespace Tensorflow.Keras.Saving
{
try
{
if (node_metadata.Identifier.Equals("_tf_keras_metric"))
{
continue;
}
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier,
node_metadata.Metadata);
}
@@ -324,7 +328,9 @@ namespace Tensorflow.Keras.Saving
Trackable obj;
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
// TODO(Rinne): implement it.
return (null, null);
//throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}
else
{
@@ -343,7 +349,7 @@ namespace Tensorflow.Keras.Saving

private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
{
// TODO: implement it.
// TODO(Rinne): implement it.
throw new NotImplementedException();
}

@@ -367,15 +373,14 @@ namespace Tensorflow.Keras.Saving
}
else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER)
{
model = model = new Sequential(new SequentialArgs
model = new Sequential(new SequentialArgs
{
Name = class_name
});
}
else
{
// TODO: implement it.
throw new NotImplementedException("Not implemented");
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>());
}

// Record this model and its layers. This will later be used to reconstruct


+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -21,7 +21,7 @@ public class SequentialModelLoad
[TestMethod]
public void SimpleModelFromSequential()
{
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential");
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb");
Debug.Assert(model is Model);
var m = model as Model;



Loading…
Cancel
Save