using System; using System.Collections.Generic; using System.Diagnostics; using System.Text; using Tensorflow.Common.Types; using Tensorflow.Keras.Layers.Rnn; using Tensorflow.Common.Extensions; namespace Tensorflow.Keras.Utils { internal static class RnnUtils { internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) { Func create_zeros; create_zeros = (GeneralizedTensorShape unnested_state_size) => { var flat_dims = unnested_state_size.ToSingleShape().dims; var init_state_size = new Tensor[] { batch_size_tensor }. Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); return array_ops.zeros(init_state_size, dtype: dtype); }; // TODO(Rinne): map structure with nested tensors. if(state_size.TotalNestedCount > 1) { return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray()); } else { return create_zeros(state_size); } } internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) { if (inputs is not null) { batch_size = array_ops.shape(inputs)[0]; dtype = inputs.dtype; } return generate_zero_filled_state(batch_size, cell.StateSize, dtype); } /// /// Standardizes `__call__` to a single list of tensor inputs. /// /// When running a model loaded from a file, the input tensors /// `initial_state` and `constants` can be passed to `RNN.__call__()` as part /// of `inputs` instead of by the dedicated keyword arguments.This method /// makes sure the arguments are separated and that `initial_state` and /// `constants` are lists of tensors(or None). /// /// Tensor or list/tuple of tensors. which may include constants /// and initial states.In that case `num_constant` must be specified. /// Tensor or list of tensors or None, initial states. /// Tensor or list of tensors or None, constant tensors. /// Expected number of constants (if constants are passed as /// part of the `inputs` list. /// internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants) { if(inputs.Length > 1) { // There are several situations here: // In the graph mode, __call__ will be only called once. The initial_state // and constants could be in inputs (from file loading). // In the eager mode, __call__ will be called twice, once during // rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be // model.fit/train_on_batch/predict with real np data. In the second case, // the inputs will contain initial_state and constants as eager tensor. // // For either case, the real input is the first item in the list, which // could be a nested structure itself. Then followed by initial_states, which // could be a list of items, or list of list if the initial_state is complex // structure, and finally followed by constants which is a flat list. Debug.Assert(initial_state is null && constants is null); if(num_constants > 0) { constants = inputs.TakeLast(num_constants).ToArray().ToTensors(); inputs = inputs.SkipLast(num_constants).ToArray().ToTensors(); } if(inputs.Length > 1) { initial_state = inputs.Skip(1).ToArray().ToTensors(); inputs = inputs.Take(1).ToArray().ToTensors(); } } return (inputs, initial_state, constants); } /// /// Check whether the state_size contains multiple states. /// /// /// public static bool is_multiple_state(GeneralizedTensorShape state_size) { return state_size.TotalNestedCount > 1; } } }