From 6e4bad40b6b36278efdbcb4d585051b549d09e14 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 15 Aug 2020 06:11:01 -0500 Subject: [PATCH] Sequential #570 --- src/TensorFlowNET.Core/APIs/tf.init.cs | 3 +- .../Keras/Activations/Activations.Relu.cs | 4 +- .../Keras/Activations/Activations.Sigmoid.cs | 26 ++++++++ .../Keras/Activations/Activations.Tanh.cs | 26 ++++++++ .../Keras/ArgsDefinition/EmbeddingArgs.cs | 15 +++++ .../Keras/ArgsDefinition/LSTMArgs.cs | 26 ++++++++ .../Keras/ArgsDefinition/LSTMCellArgs.cs | 10 +++ .../Keras/ArgsDefinition/RNNArgs.cs | 10 +++ .../Keras/Engine/InputSpec.cs | 7 +- .../Keras/Engine/Layer.Layers.cs | 42 ++++++++++++ src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 6 +- src/TensorFlowNET.Core/Keras/Engine/Node.cs | 4 +- .../Keras/Engine/Sequential.cs | 65 ++++++++++--------- src/TensorFlowNET.Core/Keras/KerasApi.cs | 35 +++++++--- .../Keras/Layers/Embedding.cs | 45 ++++++------- .../Keras/Layers/InputLayer.cs | 8 ++- src/TensorFlowNET.Core/Keras/Layers/LSTM.cs | 37 +++++++++++ .../Keras/Layers/LSTMCell.cs | 20 ++++++ src/TensorFlowNET.Core/Keras/Layers/RNN.cs | 27 ++++++++ .../Operations/Initializers/Orthogonal.cs | 14 ++++ .../Operations/Initializers/RandomUniform.cs | 4 +- .../Operations/gen_array_ops.cs | 2 +- .../Operations/gen_resource_variable_ops.cs | 2 +- .../Tensorflow.Binding.csproj | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 8 +++ .../Keras/LayersTest.cs | 33 ++++++++-- 26 files changed, 400 insertions(+), 81 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Activations/Activations.Sigmoid.cs create mode 100644 src/TensorFlowNET.Core/Keras/Activations/Activations.Tanh.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/EmbeddingArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Layers/LSTM.cs create mode 100644 src/TensorFlowNET.Core/Keras/Layers/LSTMCell.cs create mode 100644 src/TensorFlowNET.Core/Keras/Layers/RNN.cs create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 35eeb908..c674d5c9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -27,7 +27,8 @@ namespace Tensorflow public IInitializer zeros_initializer => new Zeros(); public IInitializer ones_initializer => new Ones(); public IInitializer glorot_uniform_initializer => new GlorotUniform(); - public IInitializer uniform_initializer => new RandomUniform(); + public IInitializer random_uniform_initializer => new RandomUniform(); + public IInitializer orthogonal_initializer => new Orthogonal(); public variable_scope variable_scope(string name, string default_name = null, diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs index 3958f702..af6849e9 100644 --- a/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs @@ -20,7 +20,9 @@ namespace Tensorflow.Keras return results[0]; } - throw new NotImplementedException(""); + var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features }); + + return _op.output; }; } } diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.Sigmoid.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.Sigmoid.cs new file mode 100644 index 00000000..90bfed35 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.Sigmoid.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public partial class Activations + { + public Activation Sigmoid = (features, name) => + { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Sigmoid", name, + null, + features); + + return results[0]; + } + + throw new NotImplementedException(""); + }; + } +} diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.Tanh.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.Tanh.cs new file mode 100644 index 00000000..a49bd2c1 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.Tanh.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public partial class Activations + { + public Activation Tanh = (features, name) => + { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Tanh", name, + null, + features); + + return results[0]; + } + + throw new NotImplementedException(""); + }; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/EmbeddingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/EmbeddingArgs.cs new file mode 100644 index 00000000..d40e7552 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/EmbeddingArgs.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class EmbeddingArgs : LayerArgs + { + public int InputDim { get; set; } + public int OutputDim { get; set; } + public bool MaskZero { get; set; } + public int InputLength { get; set; } = -1; + public IInitializer EmbeddingsInitializer { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs new file mode 100644 index 00000000..c4951bf6 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LSTMArgs : RNNArgs + { + public int Units { get; set; } + public Activation Activation { get; set; } + public Activation RecurrentActivation { get; set; } + public IInitializer KernelInitializer { get; set; } + public IInitializer RecurrentInitializer { get; set; } + public IInitializer BiasInitializer { get; set; } + public bool UnitForgetBias { get; set; } + public float Dropout { get; set; } + public float RecurrentDropout { get; set; } + public int Implementation { get; set; } + public bool ReturnSequences { get; set; } + public bool ReturnState { get; set; } + public bool GoBackwards { get; set; } + public bool Stateful { get; set; } + public bool TimeMajor { get; set; } + public bool Unroll { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs new file mode 100644 index 00000000..aba0d268 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LSTMCellArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs new file mode 100644 index 00000000..56a9a0dc --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RNNArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index acaf0b78..2041fe7d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -26,17 +26,22 @@ namespace Tensorflow.Keras.Engine public int? ndim; public int? min_ndim; Dictionary axes; + TensorShape shape; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, int? min_ndim = null, - Dictionary axes = null) + Dictionary axes = null, + TensorShape shape = null) { this.ndim = ndim; if (axes == null) axes = new Dictionary(); this.axes = axes; this.min_ndim = min_ndim; + this.shape = shape; + if (ndim == null && shape != null) + this.ndim = shape.ndim; } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs index 14be45b0..1cddc769 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; +using Tensorflow.Operations.Activation; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -100,5 +101,46 @@ namespace Tensorflow.Keras.Engine _layers.Add(layer); return layer; } + + protected Layer LSTM(int units, + Activation activation = null, + Activation recurrent_activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer recurrent_initializer = null, + IInitializer bias_initializer = null, + bool unit_forget_bias = true, + float dropout = 0f, + float recurrent_dropout = 0f, + int implementation = 2, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool time_major = false, + bool unroll = false) + { + var layer = new LSTM(new LSTMArgs + { + Units = units, + Activation = activation ?? tf.keras.activations.Tanh, + RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, + BiasInitializer = bias_initializer ?? tf.zeros_initializer, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + Implementation = implementation, + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + TimeMajor = time_major, + Unroll = unroll + }); + + _layers.Add(layer); + return layer; + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 8964bb0a..b3e54b5a 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -144,7 +144,9 @@ namespace Tensorflow.Keras.Engine } // using var graph = tf.keras.backend.get_graph().as_default(); - + if (!inputs.IsEagerTensor) + tf.Context.graph_mode(); + tf_with(ops.name_scope(nameScope), scope => { if (!built) @@ -157,6 +159,8 @@ namespace Tensorflow.Keras.Engine _set_mask_metadata(inputs, outputs, null); }); + tf.Context.eager_mode(); + return outputs; } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index ee734588..bb70d779 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -17,7 +17,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Layers; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine { @@ -56,6 +56,8 @@ namespace Tensorflow.Keras.Engine } // Set metadata on outputs. + var node_index = layer.InboundNodes.Count - 1; + args.Outputs.KerasHistory.Add(layer); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index 49e605c1..09d10e70 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -14,17 +14,22 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine { - public class Sequential : Model, ITensorFlowObject + /// + /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. + /// `Sequential` provides training and inference features on this model. + /// + public class Sequential { -#pragma warning disable CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false bool _is_graph_network; -#pragma warning restore CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false + Tensor inputs; Tensor outputs; bool computeOutputAndMaskJointly; @@ -32,26 +37,24 @@ namespace Tensorflow.Keras.Engine TensorShape inferredInputShape; bool hasExplicitInputShape; TF_DataType inputDType; - Layer[] layers; + List layers; + public TensorShape output_shape => outputs.TensorShape; + bool built = false; public Sequential(Layer[] layers = null, string name = null) - : base(new ModelArgs { Name = name}) { - this.layers = layers ?? new Layer[0]; - SupportsMasking = true; + this.layers = layers == null ? new List() : layers.ToList(); + // SupportsMasking = true; computeOutputAndMaskJointly = true; autoTrackSubLayers = false; hasExplicitInputShape = false; + _is_graph_network = false; } - public void __enter__() + public void add(Tensor tensor) { - - } - - public void add(Tensor layer) - { - + var layer = tensor.KerasHistory[0]; + add(layer); } /// @@ -62,9 +65,9 @@ namespace Tensorflow.Keras.Engine { built = false; var set_inputs = false; - if(layers.Length == 0) + if (layers.Count == 0) { - if(layer is InputLayer) + if (layer is InputLayer) { set_inputs = true; } @@ -93,31 +96,33 @@ namespace Tensorflow.Keras.Engine } } + else if (outputs != null) + { + outputs = layer.Apply(outputs); + } if (set_inputs || _is_graph_network) { - + _init_graph_network(inputs, outputs); } - } - - public void __exit__() - { - - } - - public void Dispose() - { + else + { + } } - public void __init__() + void _init_graph_network(Tensor inputs, Tensor outputs) { - + _is_graph_network = true; + this.inputs = inputs; + this.outputs = outputs; + built = true; + _map_graph_network(inputs, outputs); } - public void __del__() + void _map_graph_network(Tensor inputs, Tensor outputs) { - + layers.add(outputs.KerasHistory[0]); } } } diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index ef18d133..84bf61c2 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -63,18 +63,9 @@ namespace Tensorflow var layer = new InputLayer(args); - return layer.InboundNodes[0].Outputs[0]; + return layer.InboundNodes[0].Outputs; } - public static Embedding Embedding(int input_dim, - int output_dim, - IInitializer embeddings_initializer = null, - bool mask_zero = false) - => new Embedding(input_dim, - output_dim, - embeddings_initializer, - mask_zero); - public class LayersApi { public Layer Dense(int units, @@ -86,6 +77,30 @@ namespace Tensorflow Activation = activation ?? tf.keras.activations.Linear, InputShape = input_shape }); + + /// + /// Turns positive integers (indexes) into dense vectors of fixed size. + /// + /// + /// + /// + /// + /// + public Embedding Embedding(int input_dim, + int output_dim, + IInitializer embeddings_initializer = null, + bool mask_zero = false, + TensorShape input_shape = null, + int input_length = -1) + => new Embedding(new EmbeddingArgs + { + InputDim = input_dim, + OutputDim = output_dim, + MaskZero = mask_zero, + InputShape = input_shape ?? input_length, + InputLength = input_length, + EmbeddingsInitializer = embeddings_initializer + }); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index cc38f553..36c3ecaa 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -20,40 +20,37 @@ using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { + /// + /// Turns positive integers (indexes) into dense vectors of fixed size. + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding + /// public class Embedding : Layer { - private int input_dim; - private int output_dim; - private bool mask_zero; - public IVariableV1 embeddings; - public IInitializer embeddings_initializer; - int input_length; + EmbeddingArgs args; + int input_dim => args.InputDim; + int output_dim => args.OutputDim; + bool mask_zero => args.MaskZero; + IVariableV1 embeddings; + IInitializer embeddings_initializer; - public Embedding(int input_dim, int output_dim, - IInitializer embeddings_initializer = null, - bool mask_zero = false, - TF_DataType dtype = TF_DataType.TF_FLOAT, - int[] input_shape = null, - int input_length = -1) : - base(new LayerArgs - { - DType = dtype, - InputShape = input_shape ?? new[] { input_length } - }) + public Embedding(EmbeddingArgs args) + : base(args) { - this.input_dim = input_dim; - this.output_dim = output_dim; - this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer; - this.mask_zero = mask_zero; + this.args = args; + if(args.InputShape == null) + args.InputShape = args.InputLength; + + embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer; SupportsMasking = mask_zero; - this.input_length = input_length; } protected override void build(TensorShape input_shape) { - embeddings = add_weight(shape: new int[] { input_dim, output_dim }, + tf.Context.eager_mode(); + embeddings = add_weight(shape: (input_dim, output_dim), initializer: embeddings_initializer, name: "embeddings"); + tf.Context.graph_mode(); built = true; } @@ -63,7 +60,7 @@ namespace Tensorflow.Keras.Layers if (dtype != tf.int32 && dtype != tf.int64) inputs = math_ops.cast(inputs, tf.int32); - var outputs = embedding_ops.embedding_lookup(embeddings, inputs[0]); + var outputs = embedding_ops.embedding_lookup(embeddings, inputs); return outputs; } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs index 055ed373..74e14e94 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs @@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Layers args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); } + if(args.DType == TF_DataType.DtInvalid) + { + args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; + } + if (args.InputTensor == null) { if(args.InputShape != null) @@ -72,7 +77,8 @@ namespace Tensorflow.Keras.Layers shape: BatchInputShape, dtype: DType, name: Name, - sparse: args.Sparse); + sparse: args.Sparse, + ragged: args.Ragged); tf.Context.eager_mode(); isPlaceholder = true; diff --git a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs new file mode 100644 index 00000000..41ce3033 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Long Short-Term Memory layer - Hochreiter 1997. + /// + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// + public class LSTM : RNN + { + LSTMArgs args; + InputSpec[] state_spec; + + int units => args.Units; + + public LSTM(LSTMArgs args) : + base(args) + { + this.args = args; + state_spec = new[] { units, units } + .Select(dim => new InputSpec(shape: (-1, dim))) + .ToArray(); + } + + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + { + return base.call(inputs, is_training, state); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/LSTMCell.cs b/src/TensorFlowNET.Core/Keras/Layers/LSTMCell.cs new file mode 100644 index 00000000..09babb20 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/LSTMCell.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class LSTMCell : Layer + { + LSTMCellArgs args; + + public LSTMCell(LSTMCellArgs args) + : base(args) + { + this.args = args; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/RNN.cs b/src/TensorFlowNET.Core/Keras/Layers/RNN.cs new file mode 100644 index 00000000..0f00058f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/RNN.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class RNN : Layer + { + public RNN(RNNArgs args) + : base(args) + { + + } + + protected Tensor get_initial_state(Tensor inputs) + { + return _generate_zero_filled_state_for_cell(null, null); + } + + Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs new file mode 100644 index 00000000..27571671 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + public class Orthogonal : IInitializer + { + public Tensor Apply(InitializerArgs args) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index c2e9889b..65c36611 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -29,7 +29,7 @@ namespace Tensorflow.Operations.Initializers #pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0 private TF_DataType dtype; - public RandomUniform(TF_DataType dtype = TF_DataType.DtInvalid) + public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT) { this.dtype = dtype; } @@ -37,7 +37,7 @@ namespace Tensorflow.Operations.Initializers public Tensor Apply(InitializerArgs args) { if (args.DType == TF_DataType.DtInvalid) - args.DType = this.dtype; + args.DType = dtype; return random_ops.random_uniform(args.Shape, minval: minval, diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 95d4d078..85fcdae8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -193,7 +193,7 @@ namespace Tensorflow /// public static Tensor identity(Tensor input, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Identity", name, diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index 87947285..8d6a5cfa 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -140,7 +140,7 @@ namespace Tensorflow /// public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "ReadVariableOp", name, diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index b620cf09..f101418f 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 2.2.0 - 0.20.0-preview3 + 0.20.0-preview4 8.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 4d586160..21d5ba00 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -25,6 +25,7 @@ using System.Text; using static Tensorflow.Binding; using Tensorflow.Eager; using Tensorflow.Framework; +using Tensorflow.Keras.Engine; namespace Tensorflow { @@ -97,6 +98,8 @@ namespace Tensorflow /// public SafeTensorHandleHandle EagerTensorHandle { get; set; } + public bool IsEagerTensor => this is EagerTensor; + /// /// Returns the shape of a tensor. /// @@ -138,6 +141,11 @@ namespace Tensorflow public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); + /// + /// Keras History: (Layer, (node_index, tensor_index)) + /// + public List KerasHistory = new List(); + /// /// Updates the shape of this tensor. /// diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index 88d5de78..cffa4cb6 100644 --- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest.Keras /// /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers /// - [TestClass, Ignore] + [TestClass] public class LayersTest : GraphModeTestBase { [TestMethod] @@ -23,11 +23,15 @@ namespace TensorFlowNET.UnitTest.Keras model.add(tf.keras.Input(shape: 16)); } - [TestMethod] + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding + /// + [TestMethod, Ignore] public void Embedding() { var model = new Sequential(); - model.add(new Embedding(1000, 64, input_length: 10)); + var layer = tf.keras.layers.Embedding(1000, 64, input_length: 10); + model.add(layer); // the model will take as input an integer matrix of size (batch, // input_length). // the largest integer (i.e. word index) in the input should be no larger @@ -35,15 +39,32 @@ namespace TensorFlowNET.UnitTest.Keras // now model.output_shape == (None, 10, 64), where None is the batch // dimension. var input_array = np.random.randint(1000, size: (32, 10)); - model.compile("rmsprop", "mse"); + // model.compile("rmsprop", "mse"); + // output_array = model.predict(input_array) } + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense + /// [TestMethod] public void Dense() { + // Create a `Sequential` model and add a Dense layer as the first layer. var model = tf.keras.Sequential(); - var dense_layer = tf.keras.layers.Dense(5, input_shape: 3); - model.add(dense_layer); + model.add(tf.keras.Input(shape: 16)); + model.add(tf.keras.layers.Dense(32, activation: tf.keras.activations.Relu)); + // Now the model will take as input arrays of shape (None, 16) + // and output arrays of shape (None, 32). + // Note that after the first layer, you don't need to specify + // the size of the input anymore: + model.add(tf.keras.layers.Dense(32)); + Assert.AreEqual((-1, 32), model.output_shape); + } + + [TestMethod] + public void SimpleRNN() + { + } } }