diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index 9b37f951..f86eca12 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -45,5 +45,7 @@ public IRegularizer ActivityRegularizer { get; set; } public bool Autocast { get; set; } + + public bool IsFromConfig { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs index bb6d9277..cf11595e 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs @@ -1,6 +1,6 @@ namespace Tensorflow.Keras.ArgsDefinition { - public class ResizingArgs : LayerArgs + public class ResizingArgs : PreprocessingLayerArgs { public int Height { get; set; } public int Width { get; set; } diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs new file mode 100644 index 00000000..bd86874b --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class PreprocessingLayer : Layer + { + public PreprocessingLayer(PreprocessingLayerArgs args) : base(args) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs index eeb813d7..9d0589bc 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs @@ -1,7 +1,9 @@ -using System; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; using System.Text; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -9,7 +11,7 @@ namespace Tensorflow.Keras.Layers /// Resize the batched image input to target height and width. /// The input should be a 4-D tensor in the format of NHWC. /// - public class Resizing : Layer + public class Resizing : PreprocessingLayer { ResizingArgs args; public Resizing(ResizingArgs args) : base(args) @@ -26,5 +28,12 @@ namespace Tensorflow.Keras.Layers { return new TensorShape(input_shape.dims[0], args.Height, args.Width, input_shape.dims[3]); } + + public static Resizing from_config(JObject config) + { + var args = JsonConvert.DeserializeObject(config.ToString()); + args.IsFromConfig = true; + return new Resizing(args); + } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs index 7646695b..e9839850 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs @@ -1,4 +1,5 @@ using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Text; @@ -16,7 +17,7 @@ namespace Tensorflow.Keras.Saving public int SharedObjectId { get; set; } [JsonProperty("must_restore_from_config")] public bool MustRestoreFromConfig { get; set; } - public ModelConfig Config { get; set; } + public JObject Config { get; set; } [JsonProperty("build_input_shape")] public TensorShapeConfig BuildInputShape { get; set; } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 82722cc1..621d79c5 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -5,8 +5,10 @@ using System.Linq; using System.Text.RegularExpressions; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using ThirdParty.Tensorflow.Python.Keras.Protobuf; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Saving { @@ -73,7 +75,7 @@ namespace Tensorflow.Keras.Saving { model = new Sequential(new SequentialArgs { - Name = config.Name + Name = config.GetValue("name").ToString() }); } else if (class_name == "Functional") @@ -97,7 +99,12 @@ namespace Tensorflow.Keras.Saving var class_name = metadata.ClassName; var shared_object_id = metadata.SharedObjectId; var must_restore_from_config = metadata.MustRestoreFromConfig; - + var obj = class_name switch + { + "Resizing" => Resizing.from_config(config), + _ => throw new NotImplementedException("") + }; + var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); return null; } @@ -157,5 +164,13 @@ namespace Tensorflow.Keras.Saving return false; } + + bool _try_build_layer(Layer obj, int node_id, TensorShape build_input_shape) + { + if (obj.Built) + return true; + + return false; + } } } diff --git a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs index 17772b8e..dd5f49c5 100644 --- a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs +++ b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs @@ -10,6 +10,6 @@ namespace Tensorflow.Keras.Saving public int?[] Items { get; set; } public static implicit operator TensorShape(TensorShapeConfig shape) - => new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + => shape == null ? null : new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); } }