From 747e6585e773256a8c2a405e14c0fdc8c8a2a7ca Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sat, 22 Apr 2023 17:20:27 +0800 Subject: [PATCH 1/2] Change type of BuildInputShape to KerasShapesWrapper. --- .../Extensions/JObjectExtensions.cs | 23 ++++++ .../Framework/Models/TensorSpec.cs | 2 +- .../Keras/Activations/Activations.cs | 2 +- .../ArgsDefinition/AutoSerializeLayerArgs.cs | 3 +- .../ArgsDefinition/Core/InputLayerArgs.cs | 4 +- .../Keras/ArgsDefinition/LayerArgs.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 6 +- .../CustomizedActivationJsonConverter.cs | 2 +- .../Json}/CustomizedAxisJsonConverter.cs | 6 +- .../Json}/CustomizedDTypeJsonConverter.cs | 4 +- .../CustomizedIInitializerJsonConverter.cs | 11 +-- ...stomizedKerasShapesWrapperJsonConverter.cs | 75 +++++++++++++++++++ .../CustomizedNodeConfigJsonConverter.cs | 6 +- .../Json}/CustomizedShapeJsonConverter.cs | 26 ++++--- .../Keras/Saving/KerasShapesWrapper.cs | 60 +++++++++++++++ .../Keras/Saving/ModelConfig.cs | 2 +- .../Keras/Saving/NodeConfig.cs | 2 +- src/TensorFlowNET.Core/NumPy/Axis.cs | 2 +- src/TensorFlowNET.Core/Numpy/Shape.cs | 2 +- .../Operations/Initializers/IInitializer.cs | 2 +- .../Operations/NnOps/RNNCell.cs | 9 ++- src/TensorFlowNET.Core/Tensors/TF_DataType.cs | 2 +- .../Engine/Functional.FromConfig.cs | 4 +- .../Engine/Functional.GetConfig.cs | 4 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 12 +-- src/TensorFlowNET.Keras/Engine/Model.Build.cs | 40 +++++++--- src/TensorFlowNET.Keras/Engine/Sequential.cs | 2 +- .../Layers/Activation/ELU.cs | 3 +- .../Layers/Activation/Exponential.cs | 3 +- .../Layers/Activation/SELU.cs | 3 +- .../Layers/Attention/Attention.cs | 2 +- .../Layers/Convolution/Conv2DTranspose.cs | 6 +- .../Layers/Convolution/Convolutional.cs | 8 +- src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 8 +- .../Layers/Core/EinsumDense.cs | 6 +- .../Layers/Core/Embedding.cs | 5 +- .../Layers/Core/InputLayer.cs | 13 ++-- .../Layers/Merging/Concatenate.cs | 3 +- .../Layers/Merging/Merge.cs | 3 +- .../Normalization/BatchNormalization.cs | 8 +- .../Normalization/LayerNormalization.cs | 8 +- .../Layers/Normalization/Normalization.cs | 10 ++- .../Preprocessing/PreprocessingLayer.cs | 4 +- .../Layers/Preprocessing/TextVectorization.cs | 5 +- .../Layers/Reshaping/Cropping1D.cs | 3 +- .../Layers/Reshaping/Cropping2D.cs | 3 +- .../Layers/Reshaping/Cropping3D.cs | 3 +- .../Layers/Reshaping/Permute.cs | 8 +- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 3 +- .../Layers/Rnn/SimpleRNN.cs | 8 +- .../Layers/Rnn/SimpleRNNCell.cs | 8 +- src/TensorFlowNET.Keras/Models/ModelsApi.cs | 2 +- .../Saving/KerasMetaData.cs | 7 +- .../Saving/KerasModelConfig.cs | 16 ++++ .../Saving/KerasObjectLoader.cs | 13 ++-- .../Utils/base_layer_utils.cs | 5 ++ .../Utils/generic_utils.cs | 4 +- 57 files changed, 373 insertions(+), 123 deletions(-) create mode 100644 src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs rename src/TensorFlowNET.Core/Keras/{Common => Saving/Json}/CustomizedActivationJsonConverter.cs (97%) rename src/TensorFlowNET.Core/Keras/{Common => Saving/Json}/CustomizedAxisJsonConverter.cs (92%) rename src/TensorFlowNET.Core/Keras/{Common => Saving/Json}/CustomizedDTypeJsonConverter.cs (89%) rename src/TensorFlowNET.Core/Keras/{Common => Saving/Json}/CustomizedIInitializerJsonConverter.cs (88%) create mode 100644 src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs rename src/TensorFlowNET.Core/Keras/{Common => Saving/Json}/CustomizedNodeConfigJsonConverter.cs (96%) rename src/TensorFlowNET.Core/Keras/{Common => Saving/Json}/CustomizedShapeJsonConverter.cs (76%) create mode 100644 src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs create mode 100644 src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs diff --git a/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs b/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs new file mode 100644 index 00000000..2e758dbf --- /dev/null +++ b/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs @@ -0,0 +1,23 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Extensions +{ + public static class JObjectExtensions + { + public static T? TryGetOrReturnNull(this JObject obj, string key) + { + var res = obj[key]; + if(res is null) + { + return default(T); + } + else + { + return res.ToObject(); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs index b6a279db..083d4813 100644 --- a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs +++ b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs @@ -7,7 +7,7 @@ namespace Tensorflow.Framework.Models public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) : base(shape, dtype, name) { - + } public TensorSpec _unbatch() diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs index 3dde625e..f0d59ed6 100644 --- a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs @@ -1,7 +1,7 @@ using Newtonsoft.Json; using System.Reflection; using System.Runtime.Versioning; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving.Common; namespace Tensorflow.Keras { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs index 59dc51b8..583ab932 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { @@ -18,7 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition [JsonProperty("dtype")] public override TF_DataType DType { get => base.DType; set => base.DType = value; } [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] - public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } [JsonProperty("trainable")] public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs index be43e0a6..e036e191 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs @@ -1,6 +1,6 @@ using Newtonsoft.Json; using Newtonsoft.Json.Serialization; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { @@ -17,6 +17,6 @@ namespace Tensorflow.Keras.ArgsDefinition [JsonProperty("dtype")] public override TF_DataType DType { get => base.DType; set => base.DType = value; } [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] - public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index febf1417..11b8ba39 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -33,7 +33,7 @@ namespace Tensorflow.Keras.ArgsDefinition /// /// Only applicable to input layers. /// - public virtual Shape BatchInputShape { get; set; } + public virtual KerasShapesWrapper BatchInputShape { get; set; } public virtual int BatchSize { get; set; } = -1; diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 1e473d75..f7669394 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Keras string Name { get; } bool Trainable { get; } bool Built { get; } - void build(Shape input_shape); + void build(KerasShapesWrapper input_shape); List Layers { get; } List InboundNodes { get; } List OutboundNodes { get; } @@ -22,8 +22,8 @@ namespace Tensorflow.Keras void set_weights(IEnumerable weights); List get_weights(); Shape OutputShape { get; } - Shape BatchInputShape { get; } - TensorShapeConfig BuildInputShape { get; } + KerasShapesWrapper BatchInputShape { get; } + KerasShapesWrapper BuildInputShape { get; } TF_DataType DType { get; } int count_params(); void adapt(Tensor data, int? batch_size = null, int? steps = null); diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs similarity index 97% rename from src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs rename to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs index 04ee79e3..b348780c 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Text; using static Tensorflow.Binding; -namespace Tensorflow.Keras.Common +namespace Tensorflow.Keras.Saving.Common { public class CustomizedActivationJsonConverter : JsonConverter { diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs similarity index 92% rename from src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs rename to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs index f6087a43..aea4af6d 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs @@ -4,7 +4,7 @@ using System; using System.Collections.Generic; using System.Text; -namespace Tensorflow.Keras.Common +namespace Tensorflow.Keras.Saving.Common { public class CustomizedAxisJsonConverter : JsonConverter { @@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { int[]? axis; - if(reader.ValueType == typeof(long)) + if (reader.ValueType == typeof(long)) { axis = new int[1]; axis[0] = (int)serializer.Deserialize(reader, typeof(int)); @@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Common { throw new ValueError("Cannot deserialize 'null' to `Axis`."); } - return new Axis((int[])(axis!)); + return new Axis(axis!); } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs similarity index 89% rename from src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs rename to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs index fce7bec5..29b3b094 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs @@ -1,7 +1,7 @@ using Newtonsoft.Json.Linq; using Newtonsoft.Json; -namespace Tensorflow.Keras.Common +namespace Tensorflow.Keras.Saving.Common { public class CustomizedDTypeJsonConverter : JsonConverter { @@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Common public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) { - var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)); + var token = JToken.FromObject(((TF_DataType)value).as_numpy_name()); token.WriteTo(writer); } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs similarity index 88% rename from src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs rename to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs index 0ff24518..a7bae56d 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs @@ -4,9 +4,10 @@ using System; using System.Collections.Generic; using System.Text; using Tensorflow.Operations; + using Tensorflow.Operations.Initializers; -namespace Tensorflow.Keras.Common +namespace Tensorflow.Keras.Saving.Common { class InitializerInfo { @@ -27,7 +28,7 @@ namespace Tensorflow.Keras.Common public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) { var initializer = value as IInitializer; - if(initializer is null) + if (initializer is null) { JToken.FromObject(null).WriteTo(writer); return; @@ -42,7 +43,7 @@ namespace Tensorflow.Keras.Common public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { var info = serializer.Deserialize(reader); - if(info is null) + if (info is null) { return null; } @@ -54,8 +55,8 @@ namespace Tensorflow.Keras.Common "Orthogonal" => new Orthogonal(info.config["gain"].ToObject(), info.config["seed"].ToObject()), "RandomNormal" => new RandomNormal(info.config["mean"].ToObject(), info.config["stddev"].ToObject(), info.config["seed"].ToObject()), - "RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject(), - maxval:info.config["maxval"].ToObject(), seed: info.config["seed"].ToObject()), + "RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject(), + maxval: info.config["maxval"].ToObject(), seed: info.config["seed"].ToObject()), "TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject(), info.config["stddev"].ToObject(), info.config["seed"].ToObject()), "VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject(), info.config["mode"].ToObject(), diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs new file mode 100644 index 00000000..1a4245bf --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs @@ -0,0 +1,75 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.Json +{ + public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(KerasShapesWrapper); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + JToken.FromObject(null).WriteTo(writer); + return; + } + if (value is not KerasShapesWrapper wrapper) + { + throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, bug got {value.GetType()}"); + } + if (wrapper.Shapes.Length == 0) + { + JToken.FromObject(null).WriteTo(writer); + } + else if (wrapper.Shapes.Length == 1) + { + JToken.FromObject(wrapper.Shapes[0]).WriteTo(writer); + } + else + { + JToken.FromObject(wrapper.Shapes).WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + if (reader.TokenType == JsonToken.StartArray) + { + TensorShapeConfig[] shapes = serializer.Deserialize(reader); + if (shapes is null) + { + return null; + } + return new KerasShapesWrapper(shapes); + } + else if (reader.TokenType == JsonToken.StartObject) + { + var shape = serializer.Deserialize(reader); + if (shape is null) + { + return null; + } + return new KerasShapesWrapper(shape); + } + else if (reader.TokenType == JsonToken.Null) + { + return null; + } + else + { + throw new ValueError($"Cannot deserialize the token type {reader.TokenType}"); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs similarity index 96% rename from src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs rename to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs index cfd8ee8f..51194a61 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs @@ -7,7 +7,7 @@ using System.Linq; using System.Text; using Tensorflow.Keras.Saving; -namespace Tensorflow.Keras.Common +namespace Tensorflow.Keras.Saving.Common { public class CustomizedNodeConfigJsonConverter : JsonConverter { @@ -46,10 +46,10 @@ namespace Tensorflow.Keras.Common { throw new ValueError("Cannot deserialize 'null' to `Shape`."); } - if(values.Length == 1) + if (values.Length == 1) { var array = values[0] as JArray; - if(array is null) + if (array is null) { throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs similarity index 76% rename from src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs rename to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs index 9d4b53a9..39799e92 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs @@ -5,14 +5,14 @@ using System; using System.Collections.Generic; using System.Text; -namespace Tensorflow.Keras.Common +namespace Tensorflow.Keras.Saving.Common { class ShapeInfoFromPython { public string class_name { get; set; } public long?[] items { get; set; } } - public class CustomizedShapeJsonConverter: JsonConverter + public class CustomizedShapeJsonConverter : JsonConverter { public override bool CanConvert(Type objectType) { @@ -25,12 +25,12 @@ namespace Tensorflow.Keras.Common public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) { - if(value is null) + if (value is null) { var token = JToken.FromObject(null); token.WriteTo(writer); } - else if(value is not Shape) + else if (value is not Shape) { throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); } @@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common { var shape = (value as Shape)!; long?[] dims = new long?[shape.ndim]; - for(int i = 0; i < dims.Length; i++) + for (int i = 0; i < dims.Length; i++) { if (shape.dims[i] == -1) { @@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Common public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { long?[] dims; - try + if (reader.TokenType == JsonToken.StartObject) { var shape_info_from_python = serializer.Deserialize(reader); if (shape_info_from_python is null) @@ -70,14 +70,22 @@ namespace Tensorflow.Keras.Common } dims = shape_info_from_python.items; } - catch(JsonSerializationException) + else if (reader.TokenType == JsonToken.StartArray) { dims = serializer.Deserialize(reader); } + else if (reader.TokenType == JsonToken.Null) + { + return null; + } + else + { + throw new ValueError($"Cannot deserialize the token {reader} as Shape."); + } long[] convertedDims = new long[dims.Length]; - for(int i = 0; i < dims.Length; i++) + for (int i = 0; i < dims.Length; i++) { - convertedDims[i] = dims[i] ?? (-1); + convertedDims[i] = dims[i] ?? -1; } return new Shape(convertedDims); } diff --git a/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs b/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs new file mode 100644 index 00000000..d91d3161 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs @@ -0,0 +1,60 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using System.Diagnostics; +using OneOf.Types; +using Tensorflow.Keras.Saving.Json; + +namespace Tensorflow.Keras.Saving +{ + [JsonConverter(typeof(CustomizedKerasShapesWrapperJsonConverter))] + public class KerasShapesWrapper + { + public TensorShapeConfig[] Shapes { get; set; } + + public KerasShapesWrapper(Shape shape) + { + Shapes = new TensorShapeConfig[] { shape }; + } + + public KerasShapesWrapper(TensorShapeConfig shape) + { + Shapes = new TensorShapeConfig[] { shape }; + } + + public KerasShapesWrapper(TensorShapeConfig[] shapes) + { + Shapes = shapes; + } + + public KerasShapesWrapper(IEnumerable shape) + { + Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); + } + + public Shape ToSingleShape() + { + Debug.Assert(Shapes.Length == 1); + var shape_config = Shapes[0]; + Debug.Assert(shape_config is not null); + return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray()); + } + + public Shape[] ToShapeArray() + { + return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); + } + + public static implicit operator KerasShapesWrapper(Shape shape) + { + return new KerasShapesWrapper(shape); + } + public static implicit operator KerasShapesWrapper(TensorShapeConfig shape) + { + return new KerasShapesWrapper(shape); + } + + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index 934d3b15..8ddcd1f0 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -9,7 +9,7 @@ using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; namespace Tensorflow.Keras.Saving { - public class ModelConfig : IKerasConfig + public class FunctionalConfig : IKerasConfig { [JsonProperty("name")] public string Name { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs index 20e2fef5..8337ae01 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -2,7 +2,7 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving.Common; namespace Tensorflow.Keras.Saving { diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 709ca9b2..976c764f 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -19,7 +19,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Text; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving.Common; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index ecf73586..c339f12d 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -19,7 +19,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Text; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving.Common; using Tensorflow.NumPy; namespace Tensorflow diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index ca8348aa..35b92448 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -16,7 +16,7 @@ using Newtonsoft.Json; using System.Collections.Generic; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving.Common; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 5847e31a..ecc9ca11 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -80,9 +80,9 @@ namespace Tensorflow public Shape OutputShape => throw new NotImplementedException(); - public Shape BatchInputShape => throw new NotImplementedException(); + public KerasShapesWrapper BatchInputShape => throw new NotImplementedException(); - public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); + public KerasShapesWrapper BuildInputShape => throw new NotImplementedException(); public TF_DataType DType => throw new NotImplementedException(); protected bool built = false; @@ -162,6 +162,11 @@ namespace Tensorflow throw new NotImplementedException(); } + public void build(KerasShapesWrapper input_shape) + { + throw new NotImplementedException(); + } + public Trackable GetTrackable() { throw new NotImplementedException(); } public void adapt(Tensor data, int? batch_size = null, int? steps = null) diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 0f514b42..2a6f7114 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -1,5 +1,5 @@ using Newtonsoft.Json; -using Tensorflow.Keras.Common; +using Tensorflow.Keras.Saving.Common; namespace Tensorflow { diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs index f4407265..7b826af8 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine { public partial class Functional { - public static Functional from_config(ModelConfig config) + public static Functional from_config(FunctionalConfig config) { var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); var model = new Functional(input_tensors, output_tensors, name: config.Name); @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - public static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config, Dictionary? created_layers = null) + public static (Tensors, Tensors, Dictionary) reconstruct_from_config(FunctionalConfig config, Dictionary? created_layers = null) { // Layer instances created during the graph reconstruction process. created_layers = created_layers ?? new Dictionary(); diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 3aeb3200..df77e596 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -19,9 +19,9 @@ namespace Tensorflow.Keras.Engine /// /// Builds the config, which consists of the node graph and serialized layers. /// - ModelConfig get_network_config() + FunctionalConfig get_network_config() { - var config = new ModelConfig + var config = new FunctionalConfig { Name = name }; diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 11a0584c..7462b136 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -211,9 +211,9 @@ namespace Tensorflow.Keras.Engine protected bool computePreviousMask; protected List updates; - public Shape BatchInputShape => args.BatchInputShape; - protected TensorShapeConfig _buildInputShape = null; - public TensorShapeConfig BuildInputShape => _buildInputShape; + public KerasShapesWrapper BatchInputShape => args.BatchInputShape; + protected KerasShapesWrapper _buildInputShape = null; + public KerasShapesWrapper BuildInputShape => _buildInputShape; List inboundNodes; public List InboundNodes => inboundNodes; @@ -284,7 +284,7 @@ namespace Tensorflow.Keras.Engine // Manage input shape information if passed. if (args.BatchInputShape == null && args.InputShape != null) { - args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); + args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); } } @@ -363,7 +363,7 @@ namespace Tensorflow.Keras.Engine tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); } - build(inputs.shape); + build(new KerasShapesWrapper(inputs.shape)); if (need_restore_mode) tf.Context.restore_mode(); @@ -371,7 +371,7 @@ namespace Tensorflow.Keras.Engine built = true; } - public virtual void build(Shape input_shape) + public virtual void build(KerasShapesWrapper input_shape) { _buildInputShape = input_shape; built = true; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Build.cs b/src/TensorFlowNET.Keras/Engine/Model.Build.cs index a51b9434..69afdef9 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Build.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Build.cs @@ -1,6 +1,8 @@ using System; using System.Linq; using Tensorflow.Graphs; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -8,22 +10,40 @@ namespace Tensorflow.Keras.Engine { public partial class Model { - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - if (this is Functional || this is Sequential) + if (_is_graph_network || this is Functional || this is Sequential) { base.build(input_shape); return; } - var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); - - graph.as_default(); - - var x = tf.placeholder(DType, input_shape); - Call(x, training: false); - - graph.Exit(); + if(input_shape is not null && this.inputs is null) + { + var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); + graph.as_default(); + var shapes = input_shape.ToShapeArray(); + var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); + try + { + Call(x, training: false); + } + catch (InvalidArgumentError) + { + throw new ValueError("You cannot build your model by calling `build` " + + "if your layers do not support float type inputs. " + + "Instead, in order to instantiate and build your " + + "model, `call` your model on real tensor data (of the correct dtype)."); + } + catch (TypeError) + { + throw new ValueError("You cannot build your model by calling `build` " + + "if your layers do not support float type inputs. " + + "Instead, in order to instantiate and build your " + + "model, `call` your model on real tensor data (of the correct dtype)."); + } + graph.Exit(); + } base.build(input_shape); } diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index c9b8cfac..90167a9d 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -92,7 +92,7 @@ namespace Tensorflow.Keras.Engine { // Instantiate an input layer. var x = keras.Input( - batch_input_shape: layer.BatchInputShape, + batch_input_shape: layer.BatchInputShape.ToSingleShape(), dtype: layer.DType, name: layer.Name + "_input"); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 9cb5b756..739c0d56 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { @@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers { this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { if (alpha < 0f) { diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index 981f96f0..17636302 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { @@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers { { // Exponential has no args } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { base.build(input_shape); } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 9b5bc0e6..53101fbb 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { @@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers { public SELU ( LayerArgs args ) : base(args) { // SELU has no arguments } - public override void build(Shape input_shape) { + public override void build(KerasShapesWrapper input_shape) { if ( alpha < 0f ) { throw new ValueError("Alpha must be a number greater than 0."); } diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs index c5131630..e6a8e1a6 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Layers } // Creates variable when `use_scale` is True or `score_mode` is `concat`. - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { if (this.use_scale) this.scale = this.add_weight(name: "scale", diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index de4080b0..13bea627 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -19,6 +19,7 @@ using static Tensorflow.Binding; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Utils; using static Tensorflow.KerasApi; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -58,13 +59,14 @@ namespace Tensorflow.Keras.Layers return args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { + var single_shape = input_shape.ToSingleShape(); if (len(input_shape) != 4) throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); var channel_axis = _get_channel_axis(); - var input_dim = input_shape[-1]; + var input_dim = single_shape[-1]; var kernel_shape = new Shape(kernel_size[0], kernel_size[1], filters, input_dim); kernel = add_weight(name: "kernel", diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 8f6a6c5b..c575362c 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; using Tensorflow.Operations; using static Tensorflow.Binding; @@ -57,12 +58,13 @@ namespace Tensorflow.Keras.Layers _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { int channel_axis = data_format == "channels_first" ? 1 : -1; + var single_shape = input_shape.ToSingleShape(); var input_channel = channel_axis < 0 ? - input_shape.dims[input_shape.ndim + channel_axis] : - input_shape.dims[channel_axis]; + single_shape.dims[single_shape.ndim + channel_axis] : + single_shape.dims[channel_axis]; Shape kernel_shape = kernel_size.dims.concat(new long[] { input_channel / args.Groups, filters }); kernel = add_weight(name: "kernel", shape: kernel_shape, diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index decdcb1d..b1cc2446 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -16,9 +16,11 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -41,10 +43,12 @@ namespace Tensorflow.Keras.Layers this.inputSpec = new InputSpec(min_ndim: 2); } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { _buildInputShape = input_shape; - var last_dim = input_shape.dims.Last(); + Debug.Assert(input_shape.Shapes.Length <= 1); + var single_shape = input_shape.ToSingleShape(); + var last_dim = single_shape.dims.Last(); var axes = new Dictionary(); axes[-1] = (int)last_dim; inputSpec = new InputSpec(min_ndim: 2, axes: axes); diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index c928591f..fb604f77 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text.RegularExpressions; using Tensorflow.Keras.Engine; using Tensorflow.Keras.ArgsDefinition.Core; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -119,9 +120,10 @@ namespace Tensorflow.Keras.Layers this.bias_constraint = args.BiasConstraint; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape); + var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, + input_shape.ToSingleShape(), this.partial_output_shape); var kernel_shape = shape_data.Item1; var bias_shape = shape_data.Item2; this.full_output_shape = shape_data.Item3; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 606f387b..9487a7d0 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -17,6 +17,7 @@ using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -48,13 +49,13 @@ namespace Tensorflow.Keras.Layers args.InputShape = args.InputLength; if (args.BatchInputShape == null) - args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); + args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer; SupportsMasking = mask_zero; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { tf.Context.eager_mode(); embeddings = add_weight(shape: (input_dim, output_dim), diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index a44c0bde..f7385bad 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -40,10 +40,10 @@ namespace Tensorflow.Keras.Layers built = true; SupportsMasking = true; - if (BatchInputShape != null) + if (BatchInputShape is not null) { - args.BatchSize = (int)BatchInputShape.dims[0]; - args.InputShape = BatchInputShape.dims.Skip(1).ToArray(); + args.BatchSize = (int)(BatchInputShape.ToSingleShape().dims[0]); + args.InputShape = BatchInputShape.ToSingleShape().dims.Skip(1).ToArray(); } // moved to base class @@ -63,9 +63,8 @@ namespace Tensorflow.Keras.Layers { if (args.InputShape != null) { - args.BatchInputShape = new long[] { args.BatchSize } - .Concat(args.InputShape.dims) - .ToArray(); + args.BatchInputShape = new Saving.KerasShapesWrapper(new long[] { args.BatchSize } + .Concat(args.InputShape.dims).ToArray()); } else { @@ -76,7 +75,7 @@ namespace Tensorflow.Keras.Layers graph.as_default(); args.InputTensor = keras.backend.placeholder( - shape: BatchInputShape, + shape: BatchInputShape.ToSingleShape(), dtype: DType, name: Name, sparse: args.Sparse, diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index da7e857a..a2a8286b 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { /*var shape_set = new HashSet(); var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 3cd43af9..7df654ee 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -4,6 +4,7 @@ using System.Text; using static Tensorflow.Binding; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { // output_shape = input_shape.dims[1^]; _buildInputShape = input_shape; diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index 3b8e1ee8..d02d2509 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; @@ -53,9 +54,10 @@ namespace Tensorflow.Keras.Layers axis = args.Axis.dims.Select(x => (int)x).ToArray(); } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - var ndims = input_shape.ndim; + var single_shape = input_shape.ToSingleShape(); + var ndims = single_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) args.Axis.dims[idx] = axis[idx] = ndims + x; @@ -74,7 +76,7 @@ namespace Tensorflow.Keras.Layers var axis_to_dim = new Dictionary(); foreach (var x in axis) - axis_to_dim[x] = (int)input_shape[x]; + axis_to_dim[x] = (int)single_shape[x]; inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index e19b9c30..e90c0402 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; @@ -49,16 +50,17 @@ namespace Tensorflow.Keras.Layers axis = args.Axis.axis; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - var ndims = input_shape.ndim; + var single_shape = input_shape.ToSingleShape(); + var ndims = single_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) axis[idx] = ndims + x; var axis_to_dim = new Dictionary(); foreach (var x in axis) - axis_to_dim[x] = (int)input_shape[x]; + axis_to_dim[x] = (int)single_shape[x]; inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs index c23dde69..a65154bf 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -45,10 +46,11 @@ namespace Tensorflow.Keras.Layers input_variance = args.Variance; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { base.build(input_shape); - var ndim = input_shape.ndim; + var single_shape = input_shape.ToSingleShape(); + var ndim = single_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) axis[idx] = ndim + x; @@ -57,8 +59,8 @@ namespace Tensorflow.Keras.Layers _reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); // Broadcast any reduced axes. - _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray()); - var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray(); + _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? single_shape.dims[d] : 1).ToArray()); + var mean_and_var_shape = _keep_axis.Select(d => single_shape.dims[d]).ToArray(); var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; var param_shape = input_shape; diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs index 463936a3..a032dcd0 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs @@ -77,8 +77,8 @@ namespace Tensorflow.Keras.Layers { var data_shape = data.shape; var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); - _args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones); - build(data_shape); + _args.BatchInputShape = BatchInputShape ?? new Saving.KerasShapesWrapper(new Shape(data_shape_nones)); + build(new Saving.KerasShapesWrapper(data_shape)); built = true; } } diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs index 4c52af9b..6c504006 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -35,12 +36,12 @@ namespace Tensorflow.Keras.Layers var shape = data.output_shapes[0]; if (shape.ndim == 1) data = data.map(tensor => array_ops.expand_dims(tensor, -1)); - build(data.variant_tensor.shape); + build(new KerasShapesWrapper(data.variant_tensor.shape)); var preprocessed_inputs = data.map(_preprocess); _index_lookup_layer.adapt(preprocessed_inputs); } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { base.build(input_shape); } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs index 10c15b69..9ead15cb 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Reshaping { @@ -11,7 +12,7 @@ namespace Tensorflow.Keras.Layers.Reshaping this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { if (args.cropping.rank != 1) { diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs index a8d7043e..087d59a1 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Reshaping { @@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers.Reshaping { this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { built = true; _buildInputShape = input_shape; diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs index 796c2dd3..04a1af60 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Reshaping { @@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Reshaping this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { built = true; _buildInputShape = input_shape; diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index 8e7a19a9..e391775c 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { public class Permute : Layer @@ -14,14 +15,15 @@ namespace Tensorflow.Keras.Layers { { this.dims = args.dims; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - var rank = input_shape.rank; + var single_shape = input_shape.ToSingleShape(); + var rank = single_shape.rank; if (dims.Length != rank - 1) { throw new ValueError("Dimensions must match."); } - permute = new int[input_shape.rank]; + permute = new int[single_shape.rank]; dims.CopyTo(permute, 1); built = true; _buildInputShape = input_shape; diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 6b755ece..310e8057 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn @@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers.Rnn //} } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { if (!cell.Built) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index 19669b4b..2d7aab70 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -1,5 +1,6 @@ using System.Data; using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.Saving; using Tensorflow.Operations.Activation; using static HDF.PInvoke.H5Z; using static Tensorflow.ApiDef.Types; @@ -14,12 +15,13 @@ namespace Tensorflow.Keras.Layers.Rnn this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - var input_dim = input_shape[-1]; + var single_shape = input_shape.ToSingleShape(); + var input_dim = single_shape[-1]; _buildInputShape = input_shape; - kernel = add_weight("kernel", (input_shape[-1], args.Units), + kernel = add_weight("kernel", (single_shape[-1], args.Units), initializer: args.KernelInitializer //regularizer = self.kernel_regularizer, //constraint = self.kernel_constraint, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 9e5af450..46061b21 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Rnn { @@ -18,11 +19,12 @@ namespace Tensorflow.Keras.Layers.Rnn this.args = args; } - public override void build(Shape input_shape) + public override void build(KerasShapesWrapper input_shape) { - var input_dim = input_shape[-1]; + var single_shape = input_shape.ToSingleShape(); + var input_dim = single_shape[-1]; - kernel = add_weight("kernel", (input_shape[-1], args.Units), + kernel = add_weight("kernel", (single_shape[-1], args.Units), initializer: args.KernelInitializer ); diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs index 3a997ff2..44dca58d 100644 --- a/src/TensorFlowNET.Keras/Models/ModelsApi.cs +++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Models { public class ModelsApi: IModelsApi { - public Functional from_config(ModelConfig config) + public Functional from_config(FunctionalConfig config) => Functional.from_config(config); public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) diff --git a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs index 52e32b7c..04429681 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs @@ -22,16 +22,19 @@ namespace Tensorflow.Keras.Saving public int SharedObjectId { get; set; } [JsonProperty("must_restore_from_config")] public bool MustRestoreFromConfig { get; set; } + [JsonProperty("config")] public JObject Config { get; set; } [JsonProperty("build_input_shape")] - public TensorShapeConfig BuildInputShape { get; set; } + public KerasShapesWrapper BuildInputShape { get; set; } [JsonProperty("batch_input_shape")] - public TensorShapeConfig BatchInputShape { get; set; } + public KerasShapesWrapper BatchInputShape { get; set; } [JsonProperty("activity_regularizer")] public IRegularizer ActivityRegularizer { get; set; } [JsonProperty("input_spec")] public JToken InputSpec { get; set; } [JsonProperty("stateful")] public bool? Stateful { get; set; } + [JsonProperty("model_config")] + public KerasModelConfig? ModelConfig { get; set; } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs b/src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs new file mode 100644 index 00000000..256c284a --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs @@ -0,0 +1,16 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public class KerasModelConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } + [JsonProperty("config")] + public JObject Config { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 9cdd3b50..41d1f031 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -8,6 +8,7 @@ using System.Diagnostics; using System.Linq; using System.Reflection; using System.Text.RegularExpressions; +using Tensorflow.Extensions; using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; @@ -356,7 +357,7 @@ namespace Tensorflow.Keras.Saving var (obj, setter) = _revive_from_config(identifier, metadata, node_id); if (obj is null) { - (obj, setter) = _revive_custom_object(identifier, metadata); + (obj, setter) = revive_custom_object(identifier, metadata); } if(obj is null) { @@ -398,7 +399,7 @@ namespace Tensorflow.Keras.Saving return (obj, setter); } - private (Trackable, Action) _revive_custom_object(string identifier, KerasMetaData metadata) + private (Trackable, Action) revive_custom_object(string identifier, KerasMetaData metadata) { if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) { @@ -437,7 +438,7 @@ namespace Tensorflow.Keras.Saving } else { - model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject()); + model = new Functional(new Tensors(), new Tensors(), config.TryGetOrReturnNull("name")); } // Record this model and its layers. This will later be used to reconstruct @@ -619,7 +620,7 @@ namespace Tensorflow.Keras.Saving } } - private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) + private bool _try_build_layer(Layer obj, int node_id, KerasShapesWrapper build_input_shape) { if (obj.Built) return true; @@ -679,10 +680,10 @@ namespace Tensorflow.Keras.Saving return inputs; } - private Shape _infer_input_shapes(int layer_node_id) + private KerasShapesWrapper _infer_input_shapes(int layer_node_id) { var inputs = _infer_inputs(layer_node_id); - return nest.map_structure(x => x.shape, inputs); + return new KerasShapesWrapper(nest.map_structure(x => x.shape, inputs)); } private int? _search_for_child_node(int parent_id, IEnumerable path_to_child) diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 56190a22..e6c9ed42 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -173,6 +173,11 @@ namespace Tensorflow.Keras.Utils obj is not Type; } + public static Tensor generate_placeholders_from_shape(Shape shape) + { + return array_ops.placeholder(keras.backend.floatx(), shape); + } + // recusive static bool uses_keras_history(Tensor op_input) { diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 672ac60e..6a59fb88 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Utils return args as LayerArgs; } - public static ModelConfig deserialize_model_config(JToken json) + public static FunctionalConfig deserialize_model_config(JToken json) { - ModelConfig config = new ModelConfig(); + FunctionalConfig config = new FunctionalConfig(); config.Name = json["name"].ToObject(); config.Layers = new List(); var layersToken = json["layers"]; From 67cf274f7fb289db391b23cc092d772faed2e0fb Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sat, 22 Apr 2023 17:22:20 +0800 Subject: [PATCH 2/2] Remove debug informations before. --- .../Training/Saving/SavedModel/function_deserialization.cs | 6 +----- src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs | 6 ------ test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs | 4 ++-- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index af9fbeda..77b115a4 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -116,12 +116,8 @@ namespace Tensorflow.Training.Saving.SavedModel } Dictionary loaded_gradients = new(); - // Debug(Rinne) - var temp = _sort_function_defs(library, function_deps); - int i = 0; - foreach (var fdef in temp) + foreach (var fdef in _sort_function_defs(library, function_deps)) { - i++; var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); object structured_input_signature = null; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index ae7e2cf5..727d18a8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -214,12 +214,6 @@ namespace Tensorflow continue; } var proto = _proto.Nodes[node_id]; - if(node_id == 10522) - { - // Debug(Rinne) - Console.WriteLine(); - } - var temp = _get_node_dependencies(proto); foreach (var dep in _get_node_dependencies(proto).Values.Distinct()) { deps.Add(dep); diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs index 647b2ad7..90f5f380 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs @@ -18,8 +18,8 @@ namespace TensorFlowNET.Keras.UnitTest { var model = GetFunctionalModel(); var config = model.get_config(); - Debug.Assert(config is ModelConfig); - var new_model = new ModelsApi().from_config(config as ModelConfig); + Debug.Assert(config is FunctionalConfig); + var new_model = new ModelsApi().from_config(config as FunctionalConfig); Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); }