using OneOf;
using System;
using System.Collections.Generic;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using System.Linq.Expressions;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
using System.Runtime.CompilerServices;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Base class for recurrent layers.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// </summary>
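/// <remarks>
/// A minimal usage sketch (illustrative only; `cell` stands for any concrete
/// <see cref="IRnnCell"/> implementation and is not defined in this file):
/// <code>
/// IRnnCell cell = ...; // e.g. an LSTM-style cell with 4 units
/// var rnn = new RNN(cell, new RNNArgs { ReturnSequences = true });
/// var outputs = rnn.Apply(tf.ones((32, 10, 8))); // inputs: (batch, timesteps, features)
/// </code>
/// </remarks>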
public class RNN : RnnBase
{
private RNNArgs _args;
private object _input_spec = null; // or NoneValue??
private object _state_spec = null;
private Tensors _states = null;
private object _constants_spec = null;
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
private IRnnCell _cell;
public RNNArgs Args { get => _args; }
public IRnnCell Cell
{
get
{
return _cell;
}
init
{
_cell = value;
_self_tracked_trackables.Add(_cell);
}
}
public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args))
{
_args = args;
SupportsMasking = true;
Cell = cell;
// get input_shape
_args = PreConstruct(args);
_num_constants = 0;
}
public RNN(IEnumerable<IRnnCell> cells, RNNArgs args) : base(PreConstruct(args))
{
_args = args;
SupportsMasking = true;
Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs());
// get input_shape
_args = PreConstruct(args);
_num_constants = 0;
}
// States is a tuple consisting of the cells' state sizes, like (cell1.state_size, cell2.state_size, ...).
// state_size can be a single integer, a list/tuple of integers, or a TensorShape / list/tuple of TensorShapes.
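// For example, an LSTM cell carries two states (h and c), each of size `units`,
// so a stack of two such cells yields ((units, units), (units, units)).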
public Tensors States
{
get
{
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
var nested = Cell.StateSize.MapStructure<Tensor>(x => null);
_states = nested.AsNest().ToTensors();
}
return _states;
}
set { _states = value; }
}
private INestStructure<Shape> compute_output_shape(Shape input_shape)
{
var batch = input_shape[0];
var time_step = input_shape[1];
if (_args.TimeMajor)
{
(batch, time_step) = (time_step, batch);
}
// state_size is an array of ints or a positive integer
var state_size = Cell.StateSize;
if(state_size?.TotalNestedCount == 1)
{
state_size = new NestList<long>(state_size.Flatten().First());
}
Func<long, Shape> _get_output_shape = (flat_output_size) =>
{
var output_dim = new Shape(flat_output_size).as_int_list();
Shape output_shape;
if (_args.ReturnSequences)
{
if (_args.TimeMajor)
{
output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
}
else
{
output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
}
}
else
{
output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
}
return output_shape;
};
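// Example (illustrative): with batch = 32, time_step = 10 and flat_output_size = 4,
// return_sequences yields (32, 10, 4) ((10, 32, 4) if time-major); otherwise (32, 4).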
Type type = Cell.GetType();
PropertyInfo output_size_info = type.GetProperty("OutputSize"); // the C# property name, not Python's "output_size"
INestStructure<Shape> output_shape;
if (output_size_info != null)
{
output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize);
}
else
{
output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First()));
}
if (_args.ReturnState)
{
Func<long, Shape> _get_state_shape = (flat_state) =>
{
var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list());
return new Shape(state_shape);
};
var state_shape = Nest.MapStructure(_get_state_shape, state_size);
return new Nest<Shape>(new[] { output_shape, state_shape });
}
else
{
return output_shape;
}
}
private Tensors compute_mask(Tensors inputs, Tensors mask)
{
// Time step masks must be the same for each input.
// This is because the mask for an RNN is of size [batch, time_steps, 1],
// and specifies which time steps should be skipped, and a time step
// must be skipped for all inputs.
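// For example, a (batch, timesteps) boolean mask produced upstream (e.g. by an
// Embedding layer with mask_zero enabled) applies to every input tensor alike.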
mask = nest.flatten(mask)[0];
var output_mask = _args.ReturnSequences ? mask : null;
if (_args.ReturnState)
{
var state_mask = new List<Tensor>();
for (int i = 0; i < len(States); i++)
{
state_mask.Add(null);
}
return new List<Tensor> { output_mask }.concat(state_mask);
}
else
{
return output_mask;
}
}
public override void build(KerasShapesWrapper input_shape)
{
_buildInputShape = input_shape;
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);
InputSpec get_input_spec(Shape shape)
{
var input_spec_shape = shape.as_int_list();
var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
if (!_args.Stateful)
{
input_spec_shape[batch_index] = -1;
}
input_spec_shape[time_step_index] = -1;
return new InputSpec(shape: input_spec_shape);
}
Shape get_step_input_shape(Shape shape)
{
// return shape[1:] if self.time_major else (shape[0],) + shape[2:]
if (_args.TimeMajor)
{
return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
}
else
{
return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
}
}
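// Example (illustrative): a time-major input (10, 32, 8) and a batch-major
// input (32, 10, 8) both yield a per-step shape of (32, 8).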
object get_state_spec(Shape shape)
{
var state_spec_shape = shape.as_int_list();
// append batch dim
state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
return new InputSpec(shape: state_spec_shape);
}
// Check whether the input shape contains any nested shapes. It could be
// (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
// numpy inputs.
if (Cell is Layer layer && !layer.Built)
{
layer.build(input_shape);
layer.Built = true;
}
this.built = true;
}
/// <summary>
/// Runs the RNN over a sequence of inputs.
/// </summary>
/// <param name="inputs"></param>
/// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell.</param>
/// <param name="training"></param>
/// <param name="optional_args"></param>
/// <returns></returns>
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if(optional_args is not null && rnn_optional_args is null)
{
throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`");
}
Tensors? constants = rnn_optional_args?.Constants;
Tensors? mask = rnn_optional_args?.Mask;
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
// Ragged tensors are not supported for now.
int row_length = 0; // TODO(Rinne): support this param.
bool is_ragged_input = false;
_validate_args_if_ragged(is_ragged_input, mask);
(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
_maybe_reset_cell_dropout_mask(Cell);
if (Cell is StackedRNNCells)
{
var stack_cell = Cell as StackedRNNCells;
foreach (IRnnCell cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
}
}
if (mask != null)
{
// Time step masks must be the same for each input.
mask = mask.Flatten().First();
}
Shape input_shape;
if (!inputs.IsNested())
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
// TODO(Wanglongzhi2001)
input_shape = inputs.Flatten().First().shape;
}
else
{
input_shape = inputs.shape;
}
var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
if (_args.Unroll && timesteps == -1) // -1 marks an unknown (undefined) time dimension
{
throw new ValueError(
"Cannot unroll a RNN if the " +
"time dimension is undefined. \n" +
"- If using a Sequential model, " +
"specify the time dimension by passing " +
"an `input_shape` or `batch_input_shape` " +
"argument to your first layer. If your " +
"first layer is an Embedding, you can " +
"also use the `input_length` argument.\n" +
"- If using the functional API, specify " +
"the time dimension by passing a `shape` " +
"or `batch_shape` argument to your Input layer."
);
}
// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
Func<Tensors, Tensors, (Tensors, Tensors)> step;
bool is_tf_rnn_cell = false;
if (constants is not null)
{
if (!Cell.SupportOptionalArgs)
{
throw new ValueError(
$"RNN cell {Cell} does not support constants." +
$"Received: constants={constants}");
}
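// keras.backend.rnn appends the constants to the tail of the `states`
// tuple it passes to the step function, so they are split back out here.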
step = (inputs, states) =>
{
constants = new Tensors(states.TakeLast(_num_constants).ToArray());
states = new Tensors(states.SkipLast(_num_constants).ToArray());
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states);
};
}
else
{
step = (inputs, states) =>
{
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
var (output, new_states) = Cell.Apply(inputs, states);
return (output, new_states);
};
}
var (last_output, outputs, states) = keras.backend.rnn(
step,
inputs,
initial_state,
constants: constants,
go_backwards: _args.GoBackwards,
mask: mask,
unroll: _args.Unroll,
input_length: is_ragged_input ? new Tensor(row_length) : new Tensor(timesteps),
time_major: _args.TimeMajor,
zero_output_for_mask: _args.ZeroOutputForMask,
return_all_outputs: _args.ReturnSequences);
if (_args.Stateful)
{
throw new NotImplementedException("this argument havn't been developed.");
}
Tensors output = new Tensors();
if (_args.ReturnSequences)
{
// TODO(Rinne): add go_backwards parameter and revise the `row_length` param
output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
}
else
{
output = last_output;
}
if (_args.ReturnState)
{
foreach (var state in states)
{
output.Add(state);
}
return output;
}
else
{
//var tapeSet = tf.GetTapeSet();
//foreach(var tape in tapeSet)
//{
// tape.Watch(output);
//}
return output;
}
}
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if (optional_args is not null && rnn_optional_args is null)
{
throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
}
Tensors? constants = rnn_optional_args?.Constants;
(inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);
if(initial_states is null && constants is null)
{
return base.Apply(inputs);
}
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
{
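// When extra tensors are packed into `inputs`, the expected layout is
// [input, initial_state_1 .. initial_state_n, constant_1 .. constant_m].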
if (inputs.Length > 1)
{
if (_num_constants == 0) // no constants packed in: everything after the input is initial state
{
initial_state = new Tensors(inputs.Skip(1).ToArray());
}
else
{
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray());
constants = new Tensors(inputs.TakeLast(_num_constants).ToArray());
}
if (len(initial_state) == 0)
initial_state = null;
inputs = inputs[0];
}
if (_args.Stateful)
{
if (initial_state != null)
{
var tmp = new List<Tensor>();
foreach (var s in nest.flatten(States))
{
tmp.Add(tf.math.count_nonzero(s.Single()));
}
var non_zero_count = tf.add_n(tmp.ToArray());
// TODO(Rinne): replace the eager check below with tf.cond.
//initial_state = tf.cond(non_zero_count > 0, States, initial_state);
if ((int)non_zero_count.numpy() > 0)
{
initial_state = States;
}
}
else
{
initial_state = States;
}
//initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
}
else if (initial_state is null)
{
initial_state = get_initial_state(inputs);
}
if (initial_state.Length != States.Length)
{
throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
$"but it received {initial_state.Length} " +
$"initial state(s). Input received: {inputs}");
}
return (inputs, initial_state, constants);
}
private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
if (!is_ragged_input)
{
return;
}
if (_args.Unroll)
{
throw new ValueError("The input received contains RaggedTensors and does " +
"not support unrolling. Disable unrolling by passing " +
"`unroll=False` in the RNN Layer constructor.");
}
if (mask != null)
{
throw new ValueError($"The mask that was passed in was {mask}, which " +
"cannot be applied to RaggedTensor inputs. Please " +
"make sure that there is no mask injected by upstream " +
"layers.");
}
}
protected void _maybe_reset_cell_dropout_mask(ILayer cell)
{
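// Cells with dropout (e.g. LSTM-style cells with dropout/recurrent_dropout)
// cache their dropout masks across time steps; the cached masks must be
// cleared before each new forward pass so a fresh mask is sampled per batch.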
if (cell is DropoutRNNCellMixin CellDRCMixin)
{
CellDRCMixin.reset_dropout_mask();
CellDRCMixin.reset_recurrent_dropout_mask();
}
}
private static RNNArgs PreConstruct(RNNArgs args)
{
// If true, the output for masked timestep will be zeros, whereas in the
// false case, output from previous timestep is returned for masked timestep.
var zeroOutputForMask = args.ZeroOutputForMask;
Shape input_shape;
var propIS = args.InputShape;
var propID = args.InputDim;
var propIL = args.InputLength;
if (propIS == null && (propID != null || propIL != null))
{
input_shape = new Shape(
propIL ?? -1,
propID ?? -1);
args.InputShape = input_shape;
}
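// Example (illustrative): InputDim = 8 with InputLength = 10 folds into
// InputShape = (10, 8); a missing value becomes -1 (unknown dimension).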
return args;
}
public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
{
throw new NotImplementedException();
}
// It seems that the cell cannot be passed as an interface type.
//public RNN New(IRnnArgCell cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });
//public RNN New(List<IRnnArgCell> cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });
protected Tensors get_initial_state(Tensors inputs)
{
var input = inputs[0];
var input_shape = array_ops.shape(inputs);
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;
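// The cell builds zero-filled initial states from the batch size and dtype;
// e.g. an LSTM-style cell would return two zero tensors of shape (batch_size, units).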
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);
return init_state;
}
public override IKerasConfig get_config()
{
return _args;
}
}
}