Add IRegularizerApi with name-based regularizer lookup.

* IRegularizer.cs: declare IRegularizerApi (GetRegularizerFromName, L1, L2, L1L2).
* L1.cs: ClassName now returns "L1" instead of "L2".
* Regularizers.cs: implement IRegularizerApi; the static lookup map registers
  the three regularizers under the distinct keys "L1", "L2" and "L1L2" so that
  no registration overwrites another.
* ModelLoadTest.cs: add the BiasRegularizerSaveAndLoad test.

diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
index e5de76dd..06dbb7c8 100644
--- a/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
+++ b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
@@ -12,5 +12,14 @@ namespace Tensorflow.Keras
         [JsonProperty("config")]
         IDictionary Config { get; }
         Tensor Apply(RegularizerArgs args);
-    }
+    }
+
+    public interface IRegularizerApi
+    {
+        IRegularizer GetRegularizerFromName(string name);
+        IRegularizer L1 { get; }
+        IRegularizer L2 { get; }
+        IRegularizer L1L2 { get; }
+    }
+
 }
diff --git a/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs b/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs
index 8a5c6889..9e061945 100644
--- a/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs
+++ b/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs
@@ -9,7 +9,7 @@ namespace Tensorflow.Operations.Regularizers
         float _l1;
         private readonly Dictionary _config;
 
-        public string ClassName => "L2";
+        public string ClassName => "L1";
         public virtual IDictionary Config => _config;
 
         public L1(float l1 = 0.01f)
diff --git a/src/TensorFlowNET.Keras/Regularizers.cs b/src/TensorFlowNET.Keras/Regularizers.cs
index 9c6d07ca..73b72a05 100644
--- a/src/TensorFlowNET.Keras/Regularizers.cs
+++ b/src/TensorFlowNET.Keras/Regularizers.cs
@@ -1,17 +1,51 @@
-namespace Tensorflow.Keras
+using Tensorflow.Operations.Regularizers;
+
+namespace Tensorflow.Keras
 {
-    public class Regularizers
+    public class Regularizers: IRegularizerApi
     {
+        private static Dictionary<string, IRegularizer> _nameActivationMap;
+
         public IRegularizer l1(float l1 = 0.01f)
-            => new Tensorflow.Operations.Regularizers.L1(l1);
+            => new L1(l1);
         public IRegularizer l2(float l2 = 0.01f)
-            => new Tensorflow.Operations.Regularizers.L2(l2);
+            => new L2(l2);
 
         //From TF source
         //# The default value for l1 and l2 are different from the value in l1_l2
         //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
         //# and no l1 penalty.
         public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
-            => new Tensorflow.Operations.Regularizers.L1L2(l1, l2);
+            => new L1L2(l1, l2);
+
+        static Regularizers()
+        {
+            _nameActivationMap = new Dictionary<string, IRegularizer>(); // default comparer: keys are case-sensitive
+            _nameActivationMap["L1"] = new L1();
+            _nameActivationMap["L2"] = new L2();
+            _nameActivationMap["L1L2"] = new L1L2();
+        }
+
+        public IRegularizer L1 => l1();
+
+        public IRegularizer L2 => l2();
+
+        public IRegularizer L1L2 => l1l2();
+
+        public IRegularizer GetRegularizerFromName(string name)
+        {
+            if (name == null)
+            {
+                throw new Exception($"Regularizer name cannot be null");
+            }
+            if (!_nameActivationMap.TryGetValue(name, out var res))
+            {
+                throw new Exception($"Regularizer {name} not found");
+            }
+            else
+            {
+                return res;
+            }
+        }
     }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
index 53a67cbf..c733537e 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
@@ -1,6 +1,7 @@
 using Microsoft.VisualStudio.TestPlatform.Utilities;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Newtonsoft.Json.Linq;
+using System.Collections.Generic;
 using System.Linq;
 using System.Xml.Linq;
 using Tensorflow.Keras.Engine;
@@ -129,6 +130,53 @@ public class ModelLoadTest
 
     }
 
+    [TestMethod]
+    public void BiasRegularizerSaveAndLoad()
+    {
+        var savemodel = keras.Sequential(new List()
+        {
+            tf.keras.layers.InputLayer((227, 227, 3)),
+            tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
+            tf.keras.layers.BatchNormalization(),
+            tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),
+
+            tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2),
+            tf.keras.layers.BatchNormalization(),
+
+            tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2),
+            tf.keras.layers.BatchNormalization(),
+
+            tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1),
+            tf.keras.layers.BatchNormalization(),
+            tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),
+
+            tf.keras.layers.Flatten(),
+
+            tf.keras.layers.Dense(1000, activation: "linear"),
+            tf.keras.layers.Softmax(1)
+        });
+
+        savemodel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
+
+        var num_epochs = 1;
+        var batch_size = 8;
+
+        var trainDataset = new RandomDataSet(new Shape(227, 227, 3), 16);
+
+        savemodel.fit(trainDataset.Data, trainDataset.Labels, batch_size, num_epochs);
+
+        savemodel.save(@"./bias_regularizer_save_and_load", save_format: "tf");
+
+        var loadModel = tf.keras.models.load_model(@"./bias_regularizer_save_and_load");
+        loadModel.summary();
+
+        loadModel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
+
+        var fitDataset = new RandomDataSet(new Shape(227, 227, 3), 16);
+
+        loadModel.fit(fitDataset.Data, fitDataset.Labels, batch_size, num_epochs);
+    }
+
     [TestMethod]
     public void CreateConcatenateModelSaveAndLoad()