@@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
{ | { | ||||
var axis = serializer.Deserialize(reader, typeof(int[])); | |||||
int[]? axis; | |||||
if(reader.ValueType == typeof(long)) | |||||
{ | |||||
axis = new int[1]; | |||||
axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | |||||
} | |||||
else | |||||
{ | |||||
axis = serializer.Deserialize(reader, typeof(int[])) as int[]; | |||||
} | |||||
if (axis is null) | if (axis is null) | ||||
{ | { | ||||
throw new ValueError("Cannot deserialize 'null' to `Axis`."); | throw new ValueError("Cannot deserialize 'null' to `Axis`."); | ||||
@@ -51,8 +51,26 @@ namespace Tensorflow.Keras.Common | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
{ | { | ||||
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
if(dims is null) | |||||
long?[] dims; | |||||
try | |||||
{ | |||||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
} | |||||
catch (JsonSerializationException ex) | |||||
{ | |||||
if (reader.Value.Equals("class_name")) | |||||
{ | |||||
reader.Read(); | |||||
reader.Read(); | |||||
reader.Read(); | |||||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
} | |||||
else | |||||
{ | |||||
throw ex; | |||||
} | |||||
} | |||||
if (dims is null) | |||||
{ | { | ||||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | throw new ValueError("Cannot deserialize 'null' to `Shape`."); | ||||
} | } | ||||
@@ -11,6 +11,7 @@ using pbc = global::Google.Protobuf.Collections; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
using Tensorflow.Functions; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -75,7 +75,14 @@ namespace Tensorflow.Keras.Engine | |||||
this.inputs = inputs; | this.inputs = inputs; | ||||
this.outputs = outputs; | this.outputs = outputs; | ||||
built = true; | built = true; | ||||
_buildInputShape = inputs.shape; | |||||
if(inputs.Length > 0) | |||||
{ | |||||
_buildInputShape = inputs.shape; | |||||
} | |||||
else | |||||
{ | |||||
_buildInputShape = new Saving.TensorShapeConfig(); | |||||
} | |||||
if (outputs.Any(x => x.KerasHistory == null)) | if (outputs.Any(x => x.KerasHistory == null)) | ||||
base_layer_utils.create_keras_history(outputs); | base_layer_utils.create_keras_history(outputs); | ||||
@@ -72,6 +72,10 @@ namespace Tensorflow.Keras.Saving | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
if (node_metadata.Identifier.Equals("_tf_keras_metric")) | |||||
{ | |||||
continue; | |||||
} | |||||
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | ||||
node_metadata.Metadata); | node_metadata.Metadata); | ||||
} | } | ||||
@@ -324,7 +328,9 @@ namespace Tensorflow.Keras.Saving | |||||
Trackable obj; | Trackable obj; | ||||
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | ||||
{ | { | ||||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
// TODO(Rinne): implement it. | |||||
return (null, null); | |||||
//throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -343,7 +349,7 @@ namespace Tensorflow.Keras.Saving | |||||
private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | ||||
{ | { | ||||
// TODO: implement it. | |||||
// TODO(Rinne): implement it. | |||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
@@ -367,15 +373,14 @@ namespace Tensorflow.Keras.Saving | |||||
} | } | ||||
else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | ||||
{ | { | ||||
model = model = new Sequential(new SequentialArgs | |||||
model = new Sequential(new SequentialArgs | |||||
{ | { | ||||
Name = class_name | Name = class_name | ||||
}); | }); | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
// TODO: implement it. | |||||
throw new NotImplementedException("Not implemented"); | |||||
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>()); | |||||
} | } | ||||
// Record this model and its layers. This will later be used to reconstruct | // Record this model and its layers. This will later be used to reconstruct | ||||
@@ -21,7 +21,7 @@ public class SequentialModelLoad | |||||
[TestMethod] | [TestMethod] | ||||
public void SimpleModelFromSequential() | public void SimpleModelFromSequential() | ||||
{ | { | ||||
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential"); | |||||
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb"); | |||||
Debug.Assert(model is Model); | Debug.Assert(model is Model); | ||||
var m = model as Model; | var m = model as Model; | ||||