// NOTE(review): this file contained unresolved git merge conflict markers and
// duplicated members. Resolved in favor of the HEAD branch (the IRnnCell
// refactor: `_field` naming, IRnnCell-typed cell, RnnUtils helpers); the
// master-side duplicates of States/compute_output_shape/compute_mask/
// get_initial_state and the reflection-based cell validation were dropped.
using OneOf;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
{
    /// <summary>
    /// Base class for recurrent layers.
    /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    /// for details about the usage of RNN API.
    /// </summary>
    public class RNN : RnnBase
    {
        private RNNArgs _args;
        private object _input_spec = null; // or NoneValue??
        private object _state_spec = null;
        private Tensors _states = null;
        private object _constants_spec = null;
        // Number of trailing tensors in a packed input that are "constants"
        // passed to the cell at every timestep.
        private int _num_constants;
        protected IVariableV1 _kernel;
        protected IVariableV1 _bias;
        protected IRnnCell _cell;

        /// <summary>
        /// Wraps a single <see cref="IRnnCell"/> (or a list of cells, which is
        /// stacked into a <see cref="StackedRNNCells"/>) as a recurrent layer.
        /// </summary>
        public RNN(RNNArgs args) : base(PreConstruct(args))
        {
            _args = args;
            SupportsMasking = true;

            // If a list of cells was supplied, stack them into a single cell.
            if (args.Cells != null)
            {
                _cell = new StackedRNNCells(new StackedRNNCellsArgs
                {
                    Cells = args.Cells
                });
            }
            else
            {
                _cell = args.Cell;
            }

            // Normalize kwargs-derived input shape information.
            _args = PreConstruct(args);

            _num_constants = 0;
        }

        // States is a tuple consisting of the cell state sizes, like
        // (cell1.state_size, cell2.state_size, ...). state_size can be a single
        // integer, a list/tuple of integers, a TensorShape or a list/tuple of
        // TensorShape.
        public Tensors States
        {
            get
            {
                if (_states == null)
                {
                    // Lazily create a null placeholder per state entry.
                    // CHECK(Rinne): check if this is correct.
                    var nested = _cell.StateSize.MapStructure<Tensor>(x => null);
                    _states = nested.AsNest().ToTensors();
                }
                return _states;
            }
            set { _states = value; }
        }

        /// <summary>
        /// Computes the layer's output shape (and, when `return_state` is set,
        /// the state shape) from the input shape.
        /// </summary>
        private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
        {
            var batch = input_shape[0];
            var time_step = input_shape[1];
            if (_args.TimeMajor)
            {
                // Time-major inputs are [time, batch, ...].
                (batch, time_step) = (time_step, batch);
            }

            // state_size is an array of ints or a positive integer.
            var state_size = _cell.StateSize.ToSingleShape();

            // TODO(wanglongzhi2001): should flat_output_size be a Shape or a Tensor?
            Func<Shape, Shape> _get_output_shape;
            _get_output_shape = (flat_output_size) =>
            {
                var output_dim = flat_output_size.as_int_list();
                Shape output_shape;
                if (_args.ReturnSequences)
                {
                    if (_args.TimeMajor)
                    {
                        output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
                    }
                    else
                    {
                        output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
                    }
                }
                else
                {
                    output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
                }
                return output_shape;
            };

            Type type = _cell.GetType();
            PropertyInfo output_size_info = type.GetProperty("output_size");
            Shape output_shape;
            if (output_size_info != null)
            {
                output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
                // TODO(wanglongzhi2001): should output_shape simply be a tuple or a Shape?
                output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
            }
            else
            {
                // Cells without an output_size fall back to state_size.
                output_shape = _get_output_shape(state_size);
            }

            if (_args.ReturnState)
            {
                Func<Shape, Shape> _get_state_shape;
                _get_state_shape = (flat_state) =>
                {
                    var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
                    return new Shape(state_shape);
                };
                var state_shape = _get_state_shape(state_size);

                return new List<Shape> { output_shape, state_shape };
            }
            else
            {
                return output_shape;
            }
        }

        /// <summary>
        /// Computes the output mask from the input mask.
        /// </summary>
        private Tensors compute_mask(Tensors inputs, Tensors mask)
        {
            // Time step masks must be the same for each input.
            // This is because the mask for an RNN is of size [batch, time_steps, 1],
            // and specifies which time steps should be skipped, and a time step
            // must be skipped for all inputs.

            mask = nest.flatten(mask)[0];
            var output_mask = _args.ReturnSequences ? mask : null;
            if (_args.ReturnState)
            {
                // States are never masked.
                var state_mask = new List<Tensor>();
                for (int i = 0; i < len(States); i++)
                {
                    state_mask.Add(null);
                }
                return new List<Tensor> { output_mask }.concat(state_mask);
            }
            else
            {
                return output_mask;
            }
        }

        public override void build(KerasShapesWrapper input_shape)
        {
            object get_input_spec(Shape shape)
            {
                var input_spec_shape = shape.as_int_list();

                var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
                if (!_args.Stateful)
                {
                    // Non-stateful layers accept any batch size.
                    input_spec_shape[batch_index] = -1;
                }
                input_spec_shape[time_step_index] = -1;
                return new InputSpec(shape: input_spec_shape);
            }

            Shape get_step_input_shape(Shape shape)
            {
                // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
                if (_args.TimeMajor)
                {
                    return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
                }
                else
                {
                    return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
                }
            }

            object get_state_spec(Shape shape)
            {
                var state_spec_shape = shape.as_int_list();
                // append batch dim
                state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
                return new InputSpec(shape: state_spec_shape);
            }

            // Check whether the input shape contains any nested shapes. It could be
            // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
            // numpy inputs.

            if (!_cell.Built)
            {
                _cell.build(input_shape);
            }
        }

        /// <summary>
        /// Runs the wrapped cell over the time dimension of `inputs`.
        /// </summary>
        /// <param name="inputs">Input tensors (optionally packed with initial states and constants).</param>
        /// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell.</param>
        /// <param name="training">Training-mode flag forwarded to the cell.</param>
        /// <param name="optional_args">
        /// Must be a <see cref="RnnOptionalArgs"/> (or null); carries the binary mask of shape
        /// [batch_size, timesteps] and the per-timestep constants.
        /// </param>
        protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
            if (optional_args is not null && rnn_optional_args is null)
            {
                throw new ArgumentException("The optional args should be of type `RnnOptionalArgs`");
            }
            Tensors? constants = rnn_optional_args?.Constants;
            Tensors? mask = rnn_optional_args?.Mask;
            //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
            // Ragged tensors are not accepted yet.
            // NOTE(review): with a non-nullable int, `row_length != null` below is
            // always true, so `input_length` is always Tensor(0) — confirm intent.
            int row_length = 0; // TODO(Rinne): support this param.
            bool is_ragged_input = false;
            _validate_args_if_ragged(is_ragged_input, mask);

            (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);

            _maybe_reset_cell_dropout_mask(_cell);
            if (_cell is StackedRNNCells)
            {
                var stack_cell = _cell as StackedRNNCells;
                foreach (IRnnCell cell in stack_cell.Cells)
                {
                    _maybe_reset_cell_dropout_mask(cell);
                }
            }

            if (mask != null)
            {
                // Time step masks must be the same for each input.
                mask = mask.Flatten().First();
            }

            Shape input_shape;
            if (!inputs.IsNested())
            {
                // In the case of nested input, use the first element for shape check
                // input_shape = nest.flatten(inputs)[0].shape;
                // TODO(Wanglongzhi2001)
                input_shape = inputs.Flatten().First().shape;
            }
            else
            {
                input_shape = inputs.shape;
            }

            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
            // NOTE(review): `timesteps` is a non-nullable long here, so this
            // null comparison is always false — confirm the undefined-dimension
            // sentinel used by Shape (e.g. -1) and compare against that.
            if (_args.Unroll && timesteps == null)
            {
                throw new ValueError(
                "Cannot unroll a RNN if the " +
                "time dimension is undefined. \n" +
                "- If using a Sequential model, " +
                "specify the time dimension by passing " +
                "an `input_shape` or `batch_input_shape` " +
                "argument to your first layer. If your " +
                "first layer is an Embedding, you can " +
                "also use the `input_length` argument.\n" +
                "- If using the functional API, specify " +
                "the time dimension by passing a `shape` " +
                "or `batch_shape` argument to your Input layer."
                );
            }

            // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
            Func<Tensors, Tensors, (Tensors, Tensors)> step;
            bool is_tf_rnn_cell = _cell.IsTFRnnCell;
            if (constants is not null)
            {
                if (!_cell.SupportOptionalArgs)
                {
                    throw new ValueError(
                          $"RNN cell {_cell} does not support constants." +
                          $"Received: constants={constants}");
                }

                step = (inputs, states) =>
                {
                    // The constants ride along at the tail of `states`.
                    constants = new Tensors(states.TakeLast(_num_constants));
                    states = new Tensors(states.SkipLast(_num_constants));
                    states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
                    var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
                    return (output, new_states.Single);
                };
            }
            else
            {
                step = (inputs, states) =>
                {
                    states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
                    var (output, new_states) = _cell.Apply(inputs, states);
                    return (output, new_states);
                };
            }

            var (last_output, outputs, states) = keras.backend.rnn(
                step,
                inputs,
                initial_state,
                constants: constants,
                go_backwards: _args.GoBackwards,
                mask: mask,
                unroll: _args.Unroll,
                input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
                time_major: _args.TimeMajor,
                zero_output_for_mask: _args.ZeroOutputForMask,
                return_all_outputs: _args.ReturnSequences);

            if (_args.Stateful)
            {
                throw new NotImplementedException("this argument haven't been developed.");
            }

            Tensors output = new Tensors();
            if (_args.ReturnSequences)
            {
                // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
                output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
            }
            else
            {
                output = last_output;
            }

            if (_args.ReturnState)
            {
                // Append the final states after the output tensor(s).
                foreach (var state in states)
                {
                    output.Add(state);
                }
                return output;
            }
            else
            {
                return output;
            }
        }

        public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
        {
            RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
            if (optional_args is not null && rnn_optional_args is null)
            {
                throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
            }
            Tensors? constants = rnn_optional_args?.Constants;
            // Normalize the (inputs, initial_state, constants) calling convention.
            (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);

            if (initial_states is null && constants is null)
            {
                return base.Apply(inputs);
            }

            // TODO(Rinne): implement it.
            throw new NotImplementedException();
        }

        /// <summary>
        /// Unpacks inputs packed as [inputs, initial_state..., constants...] and
        /// resolves the effective initial state (stateful layers reuse States).
        /// </summary>
        private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
        {
            if (inputs.Length > 1)
            {
                // BUGFIX(merge): the two branches were swapped — constants can
                // only be sliced off the tail when _num_constants != 0 (mirrors
                // Keras: inputs[1:-num_constants] / inputs[-num_constants:]).
                if (_num_constants != 0)
                {
                    initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
                    constants = new Tensors(inputs.TakeLast(_num_constants));
                }
                else
                {
                    initial_state = new Tensors(inputs.Skip(1));
                }
                if (len(initial_state) == 0)
                    initial_state = null;
                inputs = inputs[0];
            }

            if (_args.Stateful)
            {
                if (initial_state != null)
                {
                    // Only reuse the layer's States when they are non-zero.
                    var tmp = new Tensor[] { };
                    foreach (var s in nest.flatten(States))
                    {
                        tmp.add(tf.math.count_nonzero(s.Single()));
                    }
                    var non_zero_count = tf.add_n(tmp);
                    //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
                    if ((int)non_zero_count.numpy() > 0)
                    {
                        initial_state = States;
                    }
                }
                else
                {
                    initial_state = States;
                }
                // TODO(Wanglongzhi2001),
                // initial_state = tf.nest.map_structure(
                //# When the layer has a inferred dtype, use the dtype from the
                //# cell.
                //     lambda v: tf.cast(
                //         v, self.compute_dtype or self.cell.compute_dtype
                //     ),
                //     initial_state,
                // )
            }
            else if (initial_state is null)
            {
                initial_state = get_initial_state(inputs);
            }

            if (initial_state.Length != States.Length)
            {
                throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
                                     $"but it received {initial_state.Length} " +
                                     $"initial state(s). Input received: {inputs}");
            }

            return (inputs, initial_state, constants);
        }

        /// <summary>
        /// Ragged inputs are incompatible with unrolling and masking; reject them.
        /// </summary>
        private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
        {
            if (!is_ragged_input)
            {
                return;
            }

            if (_args.Unroll)
            {
                throw new ValueError("The input received contains RaggedTensors and does " +
                      "not support unrolling. Disable unrolling by passing " +
                      "`unroll=False` in the RNN Layer constructor.");
            }
            if (mask != null)
            {
                throw new ValueError($"The mask that was passed in was {mask}, which " +
                      "cannot be applied to RaggedTensor inputs. Please " +
                      "make sure that there is no mask injected by upstream " +
                      "layers.");
            }
        }

        // Resets cached dropout masks on cells that mix in dropout support, so
        // each forward pass samples fresh masks.
        void _maybe_reset_cell_dropout_mask(ILayer cell)
        {
            if (cell is DropoutRNNCellMixin CellDRCMixin)
            {
                CellDRCMixin.reset_dropout_mask();
                CellDRCMixin.reset_recurrent_dropout_mask();
            }
        }

        // Normalizes kwargs before the base-class constructor runs: folds
        // input_dim/input_length into an input_shape entry.
        private static RNNArgs PreConstruct(RNNArgs args)
        {
            if (args.Kwargs == null)
            {
                args.Kwargs = new Dictionary<string, object>();
            }

            // If true, the output for masked timestep will be zeros, whereas in the
            // false case, output from previous timestep is returned for masked timestep.
            var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);

            Shape input_shape;
            var propIS = (Shape)args.Kwargs.Get("input_shape", null);
            var propID = (int?)args.Kwargs.Get("input_dim", null);
            var propIL = (int?)args.Kwargs.Get("input_length", null);

            if (propIS == null && (propID != null || propIL != null))
            {
                input_shape = new Shape(
                    propIL ?? -1,
                    propID ?? -1);
                args.Kwargs["input_shape"] = input_shape;
            }

            return args;
        }

        public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Builds the initial state for the first timestep, either by asking the
        /// cell (via its `get_initial_state` method, looked up reflectively) or by
        /// generating zero-filled states of the cell's state size.
        /// </summary>
        protected Tensors get_initial_state(Tensors inputs)
        {
            var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");

            var input = inputs[0];
            // NOTE(review): this reads the shape of the whole Tensors rather than
            // of `input` — confirm that is intended for nested inputs.
            var input_shape = inputs.shape;
            var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
            var dtype = input.dtype;
            Tensors init_state = new Tensors();

            if (get_initial_state_fn != null)
            {
                init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
            }
            //if (_cell is RnnCellBase rnn_base_cell)
            //{
            //    init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
            //}
            else
            {
                init_state = RnnUtils.generate_zero_filled_state(tf.convert_to_tensor(batch_size), _cell.StateSize, dtype);
            }

            return init_state;
        }

        // Check whether the state_size contains multiple states.
        public static bool is_multiple_state(GeneralizedTensorShape state_size)
        {
            return state_size.Shapes.Length > 1;
        }
    }
}