From 5f1f59897d34a664f81d632efe00f9a8014f480d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 26 Jul 2020 22:44:07 -0500 Subject: [PATCH] Partial implementation of tf.keras. #355 --- TensorFlow.NET.sln | 54 +------------- src/TensorFlowNET.Core/APIs/tf.layers.cs | 23 +++--- src/TensorFlowNET.Core/APIs/tf.math.cs | 3 + src/TensorFlowNET.Core/Data/DatasetManager.cs | 2 +- .../Data/TensorSliceDataset.cs | 4 +- .../Exceptions/InvalidArgumentError.cs | 17 +++++ .../Gradients/GradientTape.cs | 2 +- .../Keras/ArgsDefinition/DenseArgs.cs | 56 +++++++++++++++ .../Keras/ArgsDefinition/LayerArgs.cs | 51 ++++++++++++++ .../Keras/ArgsDefinition/ModelArgs.cs | 10 +++ .../Keras/Engine/CallContext.cs | 14 ++++ .../Keras/Engine/CallContextManager.cs | 14 ++++ src/TensorFlowNET.Core/Keras/Engine/ILayer.cs | 15 ++++ .../Keras/{Layers => Engine}/Layer.cs | 70 ++++++++++++++----- src/TensorFlowNET.Core/Keras/Engine/Model.cs | 12 ++-- .../Keras/Engine/Network.cs | 55 --------------- .../Keras/Engine/Sequential.cs | 7 +- src/TensorFlowNET.Core/Keras/KerasApi.cs | 19 ++++- .../Keras/Layers/BatchNormalization.cs | 3 +- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 32 ++++----- .../Keras/Layers/Embedding.cs | 11 ++- .../Keras/Layers/InputLayer.cs | 8 ++- src/TensorFlowNET.Core/Keras/Layers/Node.cs | 1 + .../Keras/Layers/Pooling2D.cs | 2 +- src/TensorFlowNET.Core/Layers/Dense.cs | 20 ------ src/TensorFlowNET.Core/Layers/Layer.cs | 11 ++- .../Operations/NnOps/BasicLSTMCell.cs | 2 +- .../Operations/NnOps/BasicRNNCell.cs | 2 +- src/TensorFlowNET.Core/Status/Status.cs | 3 +- 30 files changed, 329 insertions(+), 196 deletions(-) create mode 100644 src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/CallContext.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/CallContextManager.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/ILayer.cs rename src/TensorFlowNET.Core/Keras/{Layers => Engine}/Layer.cs (85%) delete mode 100644 src/TensorFlowNET.Core/Keras/Engine/Network.cs delete mode 100644 src/TensorFlowNET.Core/Layers/Dense.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 7bdd47e8..a7d934d9 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,11 +9,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "src EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest", "test\TensorFlowNET.UnitTest\Tensorflow.UnitTest.csproj", "{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlowNET.Console", "src\TensorFlowNET.Console\TensorFlowNET.Console.csproj", "{03F06299-3F4B-4449-A709-3A647657BC0C}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Console", "src\TensorFlowNET.Console\TensorFlowNET.Console.csproj", "{03F06299-3F4B-4449-A709-3A647657BC0C}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -103,54 +99,6 @@ Global {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64 {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.ActiveCfg = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x86.ActiveCfg = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x86.Build.0 = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x86.Build.0 = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x86.ActiveCfg = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x86.Build.0 = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.Build.0 = Release|x64 - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x86.ActiveCfg = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x86.Build.0 = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x86.ActiveCfg = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x86.Build.0 = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x86.Build.0 = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x86.ActiveCfg = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x86.Build.0 = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|x64 - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x86.ActiveCfg = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x86.Build.0 = Release|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index e62d5fa2..3ebddbcf 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -14,9 +14,11 @@ limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; using System.Linq; using NumSharp; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -173,14 +175,19 @@ namespace Tensorflow if (bias_initializer == null) bias_initializer = tf.zeros_initializer; - var layer = new Dense(units, activation, - use_bias: use_bias, - bias_initializer: bias_initializer, - kernel_initializer: kernel_initializer, - trainable: trainable, - name: name); - - return layer.apply(inputs).Item1; + var layer = new Dense(new DenseArgs + { + Units = units, + Activation = activation, + UseBias = use_bias, + BiasInitializer = bias_initializer, + KernelInitializer = kernel_initializer, + Trainable = trainable, + Name = name + }); + + throw new NotImplementedException(""); + //return layer.apply(inputs).Item1; } /// diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 23f7753b..49bd7ca8 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -515,6 +515,9 @@ namespace Tensorflow public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) => gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name); + public Tensor reduce_mean(Tensor input_tensors, int axis, bool keepdims = false, string name = null) + => math_ops.reduce_mean(input_tensors, axis: new[] { axis }, keepdims: keepdims, name: name); + public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs index b46fb791..ddfa1835 100644 --- a/src/TensorFlowNET.Core/Data/DatasetManager.cs +++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs @@ -7,7 +7,7 @@ namespace Tensorflow { public class DatasetManager { - public IDatasetV2 from_tensor_slices(NDArray features, NDArray labels) + public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) => new TensorSliceDataset(features, labels); } } diff --git a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs index e35cd8c5..0fc88f31 100644 --- a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs +++ b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs @@ -11,9 +11,9 @@ namespace Tensorflow { public class TensorSliceDataset : DatasetSource { - public TensorSliceDataset(NDArray features, NDArray labels) + public TensorSliceDataset(Tensor features, Tensor labels) { - _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) }; + _tensors = new[] { features, labels }; var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); structure = batched_spec.Select(x => x._unbatch()).ToArray(); diff --git a/src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs b/src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs new file mode 100644 index 00000000..b16e4fd6 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs @@ -0,0 +1,17 @@ +using System; + +namespace Tensorflow +{ + public class InvalidArgumentError : TensorflowException + { + public InvalidArgumentError() : base() + { + + } + + public InvalidArgumentError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index d33d38a2..c47e8af8 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -119,7 +119,7 @@ namespace Tensorflow.Gradients return (results[0], results[1]); } - public Tensor[] gradient(Tensor target, ResourceVariable[] sources) + public Tensor[] gradient(Tensor target, IEnumerable sources) { if (_recording) { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs new file mode 100644 index 00000000..ef05f929 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class DenseArgs : LayerArgs + { + /// + /// Positive integer, dimensionality of the output space. + /// + public int Units { get; set; } + + /// + /// Activation function to use. + /// + public IActivation Activation { get; set; } + + /// + /// Whether the layer uses a bias vector. + /// + public bool UseBias { get; set; } = true; + + /// + /// Initializer for the `kernel` weights matrix. + /// + public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + + /// + /// Initializer for the bias vector. + /// + public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + + /// + /// Regularizer function applied to the `kernel` weights matrix. + /// + public IInitializer KernelRegularizer { get; set; } + + /// + /// Regularizer function applied to the bias vector. + /// + public IInitializer BiasRegularizer { get; set; } + + /// + /// Constraint function applied to the `kernel` weights matrix. + /// + public Action KernelConstraint { get; set; } + + /// + /// Constraint function applied to the bias vector. + /// + public Action BiasConstraint { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs new file mode 100644 index 00000000..8c7f6597 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LayerArgs + { + /// + /// Indicates whether the layer's weights are updated during training + /// and whether the layer's updates are run during training. + /// + public bool Trainable { get; set; } = true; + + public string Name { get; set; } + + /// + /// Only applicable to input layers. + /// + public TF_DataType DType { get; set; } + + /// + /// Whether the `call` method can be used to build a TF graph without issues. + /// This attribute has no effect if the model is created using the Functional + /// API. Instead, `model.dynamic` is determined based on the internal layers. + /// + public bool Dynamic { get; set; } = false; + + /// + /// Only applicable to input layers. + /// + public TensorShape InputShape { get; set; } + + /// + /// Only applicable to input layers. + /// + public TensorShape BatchInputShape { get; set; } + + /// + /// Initial weight values. + /// + public float[] Weights { get; set; } + + /// + /// Regularizer function applied to the output of the layer(its "activation"). + /// + public IInitializer ActivityRegularizer { get; set; } + + public bool Autocast { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs new file mode 100644 index 00000000..f8e13bbe --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ModelArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/CallContext.cs b/src/TensorFlowNET.Core/Keras/Engine/CallContext.cs new file mode 100644 index 00000000..7f63f088 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/CallContext.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public class CallContext + { + public CallContextManager enter() + { + return new CallContextManager(); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/CallContextManager.cs b/src/TensorFlowNET.Core/Keras/Engine/CallContextManager.cs new file mode 100644 index 00000000..fca7404f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/CallContextManager.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public class CallContextManager : IDisposable + { + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/ILayer.cs b/src/TensorFlowNET.Core/Keras/Engine/ILayer.cs new file mode 100644 index 00000000..0b1f422d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/ILayer.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + /// + /// A layer is a callable object that takes as input one or more tensors and + /// that outputs one or more tensors. + /// + public interface ILayer + { + Tensor Apply(Tensor inputs, bool is_training = false); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs similarity index 85% rename from src/TensorFlowNET.Core/Keras/Layers/Layer.cs rename to src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 5118cb71..60379a0c 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -17,12 +17,14 @@ using System; using System.Collections.Generic; using System.Linq; -using Tensorflow.Keras.Engine; +using System.Threading; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; using Tensorflow.Train; using static Tensorflow.Binding; -namespace Tensorflow.Keras.Layers +namespace Tensorflow.Keras.Engine { /// /// Base layer class. @@ -32,8 +34,10 @@ namespace Tensorflow.Keras.Layers /// /// tensorflow\python\keras\engine\base_layer.py /// - public class Layer : AutoTrackable + public class Layer : AutoTrackable, ILayer { + protected LayerArgs _args; + /// /// Indicates whether `build` needs to be called upon layer call, to create /// the layer's weights. @@ -52,6 +56,7 @@ namespace Tensorflow.Keras.Layers protected InputSpec input_spec; protected bool supports_masking; protected List _trainable_weights; + public List trainable_variables => _trainable_weights; protected List _non_trainable_weights; private string _name; public string name => _name; @@ -72,13 +77,12 @@ namespace Tensorflow.Keras.Layers float _initial_weights; #pragma warning restore CS0169 // The field 'Layer._initial_weights' is never used - public Layer(bool trainable = true, - string name = null, - TF_DataType dtype = TF_DataType.DtInvalid, - int[] input_shape = null) + ThreadLocal _call_context; + public CallContext CallContext => _call_context.Value; + + public Layer(LayerArgs args) { - this.trainable = trainable; - this._dtype = dtype; + _args = args; // A stateful layer is a layer whose updates are run during inference too, // for instance stateful RNNs. stateful = false; @@ -94,17 +98,47 @@ namespace Tensorflow.Keras.Layers _updates = new List(); // Manage input shape information if passed. - if(input_shape != null) + + _inbound_nodes = new List(); + } + + /// + /// Wraps `call`, applying pre- and post-processing steps. + /// + /// + /// + /// + public Tensor Apply(Tensor input, bool is_training = false) + { + var input_list = new Tensor[] { input }; + + if (_call_context == null) + _call_context = new ThreadLocal() + { + Value = new CallContext() + }; + + using var ctxManager = CallContext.enter(); + + string name_scope = ""; + if (tf.context.executing_eagerly()) { - var shapes = new List { -1 }; - shapes.AddRange(input_shape); - _batch_input_shape = shapes.ToArray(); + name_scope = _name; + } + else + { + throw new NotImplementedException(""); } - - _dtype = dtype; + tf_with(ops.name_scope(name_scope), scope => + { + if (!built) + _maybe_build(input); - _inbound_nodes = new List(); + call(input, is_training: is_training); + }); + + throw new NotImplementedException(""); } public Tensor[] __call__(Tensor[] inputs, @@ -147,7 +181,7 @@ namespace Tensorflow.Keras.Layers _maybe_build(inputs[0]); outputs = call(inputs[0], - training: training, + // training: training, state: state); (input, outputs) = _set_connectivity_metadata_(input, outputs); @@ -183,7 +217,7 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected virtual Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null) { throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index b52c5335..c98bd62d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -1,8 +1,12 @@ -using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Optimizers; namespace Tensorflow.Keras.Engine { - public class Model : Network + /// + /// `Model` groups layers into an object with training and inference features. + /// + public class Model : Layer { #pragma warning disable CS0169 // The field 'Model._cloning' is never used bool _cloning; @@ -15,8 +19,8 @@ namespace Tensorflow.Keras.Engine string loss; IOptimizer optimizer; - public Model(string name = null) - : base(name: name) + public Model(ModelArgs args) + : base(args) { } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Network.cs b/src/TensorFlowNET.Core/Keras/Engine/Network.cs deleted file mode 100644 index 86d03231..00000000 --- a/src/TensorFlowNET.Core/Keras/Engine/Network.cs +++ /dev/null @@ -1,55 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System.Collections.Generic; -using Tensorflow.Keras.Layers; - -namespace Tensorflow.Keras.Engine -{ - public class Network : Layer - { - protected bool _is_compiled; - protected bool _expects_training_arg; - protected bool _compute_output_and_mask_jointly; - /// - /// All layers in order of horizontal graph traversal. - /// Entries are unique. Includes input and output layers. - /// - protected List _layers; - - public Network(string name = null) - : base(name: name) - { - _init_subclassed_network(name); - } - - protected virtual void _init_subclassed_network(string name = null) - { - _base_init(name: name); - } - - protected virtual void _base_init(string name = null) - { - _init_set_name(name); - trainable = true; - _is_compiled = false; - _expects_training_arg = false; - _compute_output_and_mask_jointly = false; - supports_masking = false; - _layers = new List(); - } - } -} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index c6fcd68b..ff9392c8 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; namespace Tensorflow.Keras.Engine @@ -28,10 +29,10 @@ namespace Tensorflow.Keras.Engine #pragma warning restore CS0169 // The field 'Sequential.outputs' is never used public Sequential(string name = null) - : base(name: name) + : base(new ModelArgs { Name = name}) { supports_masking = true; - _compute_output_and_mask_jointly = true; + // _compute_output_and_mask_jointly = true; } public void __enter__() @@ -47,7 +48,7 @@ namespace Tensorflow.Keras.Engine { built = false; var set_inputs = false; - if(_layers.Count == 0) + //if(_layers.Count == 0) { if(layer is InputLayer) { diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index f9fd94d6..f8dcadff 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -1,6 +1,11 @@ -using System.Data; +using System; +using System.Data; using Tensorflow.Keras; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Datasets; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Operations.Activation; namespace Tensorflow { @@ -8,5 +13,17 @@ namespace Tensorflow { public KerasDataset datasets { get; } = new KerasDataset(); public Initializers initializers { get; } = new Initializers(); + public Layers layers { get; } = new Layers(); + + public class Layers + { + public ILayer Dense(int units, + IActivation activation = null) + => new Dense(new DenseArgs + { + Units = units, + Activation = activation + }); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 2d132694..ef71cd37 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -143,12 +143,13 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null) { Tensor outputs = null; if (fused) { + Tensor training = tf.convert_to_tensor(is_training); outputs = _fused_batch_norm(inputs, training: training); return new[] { outputs, outputs }; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 7f763fb8..cc04bc0f 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool training = false, Tensor state = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 1b481ada..906747b8 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -17,35 +17,29 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { - public class Dense : Tensorflow.Layers.Layer + /// + /// Just your regular densely-connected NN layer. + /// + public class Dense : Layer { protected int units; protected IActivation activation; protected bool use_bias; protected IInitializer kernel_initializer; protected IInitializer bias_initializer; - protected RefVariable kernel; - protected RefVariable bias; + protected IVariableV1 kernel; + protected IVariableV1 bias; - public Dense(int units, - IActivation activation, - string name = null, - bool use_bias = true, - bool trainable = false, - IInitializer kernel_initializer = null, - IInitializer bias_initializer = null) : base(trainable: trainable, name: name) + public Dense(DenseArgs args) : + base(args) { - this.units = units; - this.activation = activation; - this.use_bias = use_bias; - this.kernel_initializer = kernel_initializer; - this.bias_initializer = bias_initializer; this.supports_masking = true; this.input_spec = new InputSpec(min_ndim: 2); } @@ -56,14 +50,14 @@ namespace Tensorflow.Keras.Layers var axes = new Dictionary(); axes[-1] = last_dim; input_spec = new InputSpec(min_ndim: 2, axes: axes); - kernel = (RefVariable)add_weight( + kernel = add_weight( "kernel", shape: new int[] { last_dim, units }, initializer: kernel_initializer, dtype: _dtype, trainable: true); if (use_bias) - bias = (RefVariable)add_weight( + bias = add_weight( "bias", shape: new int[] { units }, initializer: bias_initializer, @@ -73,7 +67,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool training = false, Tensor state = null) { Tensor outputs = null; var rank = inputs.rank; @@ -83,7 +77,7 @@ namespace Tensorflow.Keras.Layers } else { - outputs = gen_math_ops.mat_mul(inputs, kernel); + outputs = gen_math_ops.mat_mul(inputs, kernel.Handle); } if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index eb526874..628b5ef4 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -32,7 +34,12 @@ namespace Tensorflow.Keras.Layers bool mask_zero = false, TF_DataType dtype = TF_DataType.TF_FLOAT, int[] input_shape = null, - int input_length = -1) : base(dtype: dtype, input_shape: input_shape ?? new[] { input_length }) + int input_length = -1) : + base(new LayerArgs + { + DType = dtype, + InputShape = input_shape ?? new[] { input_length } + }) { this.input_dim = input_dim; this.output_dim = output_dim; @@ -50,7 +57,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs index be5515ec..66ba6625 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs @@ -17,6 +17,8 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Layers { @@ -35,7 +37,11 @@ namespace Tensorflow.Keras.Layers TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool sparse = false, - Tensor input_tensor = null) : base(dtype: dtype, name: name) + Tensor input_tensor = null) : + base(new LayerArgs + { + DType = dtype, Name = name + }) { built = true; this.sparse = sparse; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Node.cs b/src/TensorFlowNET.Core/Keras/Layers/Node.cs index a8785c1d..11862f06 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Node.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Linq; +using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Layers { diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 646751ae..3c9d0a38 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -45,7 +45,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null) { int[] pool_shape; if (data_format == "channels_last") diff --git a/src/TensorFlowNET.Core/Layers/Dense.cs b/src/TensorFlowNET.Core/Layers/Dense.cs deleted file mode 100644 index 3edd1e71..00000000 --- a/src/TensorFlowNET.Core/Layers/Dense.cs +++ /dev/null @@ -1,20 +0,0 @@ -using Tensorflow.Operations.Activation; - -namespace Tensorflow.Layers -{ - public class Dense : Keras.Layers.Dense - { - public Dense(int units, - IActivation activation, - bool use_bias = true, - bool trainable = false, - IInitializer kernel_initializer = null) : base(units, - activation, - use_bias: use_bias, - trainable: trainable, - kernel_initializer: kernel_initializer) - { - - } - } -} diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 0d6d2f69..5c0ad97e 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -16,11 +16,12 @@ using System; using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; using static Tensorflow.Binding; namespace Tensorflow.Layers { - public class Layer : Keras.Layers.Layer + public class Layer : Keras.Engine.Layer { protected Graph _graph; @@ -34,7 +35,13 @@ namespace Tensorflow.Layers public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, - bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype) + bool? _reuse = null) : + base(new LayerArgs + { + Trainable = trainable, + Name = name, + DType = dtype + }) { // For backwards compatibility, legacy layers do not use `ResourceVariable` // by default. diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index d8d1cb3d..b08fc78d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -74,7 +74,7 @@ namespace Tensorflow /// /// /// - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null) { var one = constant_op.constant(1, dtype: dtypes.int32); // Parameters of gates are concatenated into one multiply for efficiency. diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index dfc1256f..8ddd4599 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -67,7 +67,7 @@ namespace Tensorflow built = true; } - protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null) { // Most basic RNN: output = new_state = act(W * input + U * state + B). var concat = array_ops.concat(new[] { inputs, state }, 1); diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 30051fb7..288297fb 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -85,8 +85,9 @@ namespace Tensorflow { case TF_Code.TF_OUT_OF_RANGE: throw new OutOfRangeError(message); + case TF_Code.TF_INVALID_ARGUMENT: + throw new InvalidArgumentError(message); default: - Console.WriteLine(message); throw new TensorflowException(message); } }