using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Recurrent layer configured through <see cref="RNNArgs"/>.
    /// </summary>
    /// <remarks>
    /// State handling is not implemented yet: <see cref="get_initial_state"/> delegates to a
    /// helper that unconditionally throws <see cref="NotImplementedException"/>.
    /// </remarks>
    public class RNN : Layer
    {
        // Assigned once in the constructor and never reassigned afterwards.
        private readonly RNNArgs args;

        /// <summary>
        /// Creates the layer. <paramref name="args"/> is normalized by <see cref="PreConstruct"/>
        /// before it is handed to the base constructor.
        /// </summary>
        /// <param name="args">Layer configuration; its <c>Kwargs</c> may be mutated by <see cref="PreConstruct"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="args"/> is <c>null</c>.</exception>
        public RNN(RNNArgs args) : base(PreConstruct(args))
        {
            this.args = args;
            SupportsMasking = true;

            // The input shape is unknown yet, it could have nested tensor inputs, and
            // the input spec will be the list of specs for nested inputs, the structure
            // of the input_spec will be the same as the input.

            // TODO: not ported yet from the reference implementation:
            //self.input_spec = None
            //self.state_spec = None
            //self._states = None
            //self.constants_spec = None
            //self._num_constants = 0

            //if stateful:
            //  if ds_context.has_strategy():
            //    raise ValueError('RNNs with stateful=True not yet supported with '
            //                     'tf.distribute.Strategy.')
        }

        /// <summary>
        /// Normalizes <paramref name="args"/> before the base constructor runs.
        /// </summary>
        /// <remarks>
        /// Ensures <c>Kwargs</c> is non-null. When there is no <c>"input_shape"</c> kwarg but
        /// <c>"input_dim"</c> and/or <c>"input_length"</c> is present, stores an
        /// <c>"input_shape"</c> kwarg built as the tuple <c>(input_length, input_dim)</c>,
        /// substituting a <c>NoneValue</c> for whichever of the two is missing.
        /// </remarks>
        /// <param name="args">The arguments to normalize; modified in place.</param>
        /// <returns>The same <paramref name="args"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="args"/> is <c>null</c>.</exception>
        private static RNNArgs PreConstruct(RNNArgs args)
        {
            if (args == null)
            {
                throw new ArgumentNullException(nameof(args));
            }

            if (args.Kwargs == null)
            {
                args.Kwargs = new Dictionary<string, object>();
            }

            // TODO: honor the "zero_output_for_mask" kwarg.
            // If true, the output for masked timestep will be zeros, whereas in the
            // false case, output from previous timestep is returned for masked timestep.
            // It is deliberately not read here until something actually consumes it.

            var propIS = args.Kwargs.Get("input_shape", null);
            var propID = args.Kwargs.Get("input_dim", null);
            var propIL = args.Kwargs.Get("input_length", null);

            // Only synthesize a shape when the caller did not supply one explicitly.
            if (propIS == null && (propID != null || propIL != null))
            {
                object input_shape = (
                    propIL ?? new NoneValue(),  // maybe null is needed here
                    propID ?? new NoneValue()); // and here
                args.Kwargs["input_shape"] = input_shape;
            }

            return args;
        }

        /// <summary>
        /// Builds a new <see cref="RNN"/> around a single cell. Every parameter is forwarded
        /// unchanged to the correspondingly named <see cref="RNNArgs"/> property.
        /// </summary>
        /// <param name="cell">Assigned to <c>RNNArgs.Cell</c>.</param>
        /// <param name="return_sequences">Assigned to <c>RNNArgs.ReturnSequences</c>.</param>
        /// <param name="return_state">Assigned to <c>RNNArgs.ReturnState</c>.</param>
        /// <param name="go_backwards">Assigned to <c>RNNArgs.GoBackwards</c>.</param>
        /// <param name="stateful">Assigned to <c>RNNArgs.Stateful</c>.</param>
        /// <param name="unroll">Assigned to <c>RNNArgs.Unroll</c>.</param>
        /// <param name="time_major">Assigned to <c>RNNArgs.TimeMajor</c>.</param>
        /// <returns>A newly constructed layer; this instance is not modified.</returns>
        public RNN New(LayerRnnCell cell,
            bool return_sequences = false,
            bool return_state = false,
            bool go_backwards = false,
            bool stateful = false,
            bool unroll = false,
            bool time_major = false)
            => new RNN(new RNNArgs
            {
                Cell = cell,
                ReturnSequences = return_sequences,
                ReturnState = return_state,
                GoBackwards = go_backwards,
                Stateful = stateful,
                Unroll = unroll,
                TimeMajor = time_major
            });

        /// <summary>
        /// Builds a new <see cref="RNN"/> from several cells, which are wrapped in a
        /// <c>StackedRNNCells</c> and used as the layer's cell. The remaining parameters are
        /// forwarded unchanged to the correspondingly named <see cref="RNNArgs"/> property.
        /// </summary>
        /// <param name="cell">The cells passed to <c>StackedRNNCellsArgs.Cells</c>.</param>
        /// <param name="return_sequences">Assigned to <c>RNNArgs.ReturnSequences</c>.</param>
        /// <param name="return_state">Assigned to <c>RNNArgs.ReturnState</c>.</param>
        /// <param name="go_backwards">Assigned to <c>RNNArgs.GoBackwards</c>.</param>
        /// <param name="stateful">Assigned to <c>RNNArgs.Stateful</c>.</param>
        /// <param name="unroll">Assigned to <c>RNNArgs.Unroll</c>.</param>
        /// <param name="time_major">Assigned to <c>RNNArgs.TimeMajor</c>.</param>
        /// <returns>A newly constructed layer; this instance is not modified.</returns>
        // NOTE(review): the element type of this list is not visible here -- confirm it
        // against the declared type of StackedRNNCellsArgs.Cells.
        public RNN New(IList cell,
            bool return_sequences = false,
            bool return_state = false,
            bool go_backwards = false,
            bool stateful = false,
            bool unroll = false,
            bool time_major = false)
            => new RNN(new RNNArgs
            {
                Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
                ReturnSequences = return_sequences,
                ReturnState = return_state,
                GoBackwards = go_backwards,
                Stateful = stateful,
                Unroll = unroll,
                TimeMajor = time_major
            });

        /// <summary>
        /// Placeholder for computing the initial state for <paramref name="inputs"/>.
        /// </summary>
        /// <param name="inputs">Currently ignored.</param>
        /// <returns>Never returns normally at present.</returns>
        /// <exception cref="NotImplementedException">Always, until state generation is implemented.</exception>
        protected Tensor get_initial_state(Tensor inputs)
        {
            // Neither the cell nor the batch size is wired through yet.
            return _generate_zero_filled_state_for_cell(null, null);
        }

        /// <summary>
        /// Placeholder for building a zero-filled state for <paramref name="cell"/>.
        /// </summary>
        /// <param name="cell">Currently ignored.</param>
        /// <param name="batch_size">Currently ignored.</param>
        /// <exception cref="NotImplementedException">Always.</exception>
        private Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size)
        {
            throw new NotImplementedException(
                nameof(_generate_zero_filled_state_for_cell) + " is not implemented.");
        }
    }
}