using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using OneOf;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.Util;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// Base class for recurrent layers.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of the RNN API.
/// </summary>
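/// <example>
/// A minimal usage sketch (illustrative only; it assumes an `LSTMCell`/`LSTMCellArgs`
/// with a `Units` member exists in this package):
/// <code>
/// var rnn = new RNN(new RNNArgs
/// {
///     Cell = new LSTMCell(new LSTMCellArgs { Units = 4 }),
///     ReturnSequences = true
/// });
/// var output = rnn.Apply(tf.zeros(new Shape(2, 10, 8))); // (batch, timesteps, features)
/// </code>
/// </example>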
public class RNN : RnnBase
{
private RNNArgs _args;
private object _input_spec = null; // or NoneValue??
private object _state_spec = null;
private Tensors _states = null;
private object _constants_spec = null;
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
protected IRnnCell _cell;
public RNN(RNNArgs args) : base(PreConstruct(args))
{
_args = args;
SupportsMasking = true;
// If a list of cells is passed, wrap them in a StackedRNNCells.
if (args.Cells != null)
{
_cell = new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = args.Cells
});
}
else
{
_cell = args.Cell;
}
// get input_shape
_args = PreConstruct(args);
_num_constants = 0;
}
// States is a tuple consisting of the cells' state sizes, e.g. (cell1.state_size, cell2.state_size, ...).
// state_size can be a single integer, a list/tuple of integers, a TensorShape, or a list/tuple of TensorShapes.
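// For example (illustrative, not part of the original source): a SimpleRNNCell with units=4
// reports a single state size of 4, an LSTM cell carries two states (e.g. (4, 4)), and a
// StackedRNNCells nests one entry per wrapped cell.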
public Tensors States
{
get
{
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
var nested = _cell.StateSize.MapStructure<Tensor>(x => null);
_states = nested.AsNest().ToTensors();
}
return _states;
}
set { _states = value; }
}
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
{
var batch = input_shape[0];
var time_step = input_shape[1];
if (_args.TimeMajor)
{
(batch, time_step) = (time_step, batch);
}
// state_size is an array of ints or a positive integer
var state_size = _cell.StateSize.ToSingleShape();
// TODO(wanglongzhi2001): should flat_output_size be a Shape or a Tensor?
Func<Shape, Shape> _get_output_shape = (flat_output_size) =>
{
var output_dim = flat_output_size.as_int_list();
Shape output_shape;
if (_args.ReturnSequences)
{
if (_args.TimeMajor)
{
output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
}
else
{
output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
}
}
else
{
output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
}
return output_shape;
};
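// Illustrative values for the mapping above (hypothetical): with batch=32, time_step=10 and a
// cell output dimension of 64, return_sequences=true yields (32, 10, 64) ((10, 32, 64) when
// time-major), while return_sequences=false yields (32, 64).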
Type type = _cell.GetType();
PropertyInfo output_size_info = type.GetProperty("output_size");
Shape output_shape;
if (output_size_info != null)
{
output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
// TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape?
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
}
else
{
output_shape = _get_output_shape(state_size);
}
if (_args.ReturnState)
{
Func<Shape, Shape> _get_state_shape = (flat_state) =>
{
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
return new Shape(state_shape);
};
var state_shape = _get_state_shape(state_size);
return new List<Shape> { output_shape, state_shape };
}
else
{
return output_shape;
}
}
private Tensors compute_mask(Tensors inputs, Tensors mask)
{
// Time step masks must be the same for each input.
// This is because the mask for an RNN is of size [batch, time_steps, 1],
// and specifies which time steps should be skipped, and a time step
// must be skipped for all inputs.
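// For example (illustrative): with batch=2 and time_steps=3, a mask of
// [[1, 1, 0], [1, 0, 0]] skips the last timestep of the first sample and the
// last two timesteps of the second sample, for every input.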
mask = nest.flatten(mask)[0];
var output_mask = _args.ReturnSequences ? mask : null;
if (_args.ReturnState)
{
var state_mask = new List<Tensor>();
for (int i = 0; i < len(States); i++)
{
state_mask.Add(null);
}
return new List<Tensor> { output_mask }.concat(state_mask);
}
else
{
return output_mask;
}
}
public override void build(KerasShapesWrapper input_shape)
{
object get_input_spec(Shape shape)
{
var input_spec_shape = shape.as_int_list();
var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
if (!_args.Stateful)
{
input_spec_shape[batch_index] = -1;
}
input_spec_shape[time_step_index] = -1;
return new InputSpec(shape: input_spec_shape);
}
Shape get_step_input_shape(Shape shape)
{
// return shape[1:] if self.time_major else (shape[0],) + shape[2:]
if (_args.TimeMajor)
{
return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
}
else
{
return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
}
}
object get_state_spec(Shape shape)
{
var state_spec_shape = shape.as_int_list();
// append batch dim
state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
return new InputSpec(shape: state_spec_shape);
}
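// Illustrative shapes (hypothetical): for a batch-major input (batch=32, timesteps=10, features=8),
// get_step_input_shape returns (32, 8); for a cell state shape (4,), get_state_spec yields an
// InputSpec of shape (-1, 4), where -1 stands for the batch dimension.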
// Check whether the input shape contains any nested shapes. It could be
// (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
// numpy inputs.
if (!_cell.Built)
{
_cell.build(input_shape);
}
}
/// <summary>
/// Calls the layer on new inputs.
/// </summary>
/// <param name="inputs">Input tensor(s)</param>
/// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell</param>
/// <param name="training"></param>
/// <param name="optional_args">An optional `RnnOptionalArgs` carrying `Mask` (a binary tensor of shape
/// [batch_size, timesteps] indicating whether a given timestep should be masked) and `Constants`
/// (a list of constant tensors to be passed to the cell at each timestep)</param>
/// <returns></returns>
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if(optional_args is not null && rnn_optional_args is null)
{
throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`");
}
Tensors? constants = rnn_optional_args?.Constants;
Tensors? mask = rnn_optional_args?.Mask;
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
// Ragged tensors are not accepted for now.
int? row_length = null; // TODO(Rinne): support this param.
bool is_ragged_input = false;
_validate_args_if_ragged(is_ragged_input, mask);
(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
_maybe_reset_cell_dropout_mask(_cell);
if (_cell is StackedRNNCells stack_cell)
{
foreach (IRnnCell cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
}
}
if (mask != null)
{
// Time step masks must be the same for each input.
mask = mask.Flatten().First();
}
Shape input_shape;
if (inputs.IsNested())
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
// TODO(Wanglongzhi2001)
input_shape = inputs.Flatten().First().shape;
}
else
{
input_shape = inputs.shape;
}
var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
if (_args.Unroll && timesteps == -1)
{
throw new ValueError(
"Cannot unroll a RNN if the " +
"time dimension is undefined. \n" +
"- If using a Sequential model, " +
"specify the time dimension by passing " +
"an `input_shape` or `batch_input_shape` " +
"argument to your first layer. If your " +
"first layer is an Embedding, you can " +
"also use the `input_length` argument.\n" +
"- If using the functional API, specify " +
"the time dimension by passing a `shape` " +
"or `batch_shape` argument to your Input layer."
);
}
// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
Func<Tensors, Tensors, (Tensors, Tensors)> step;
bool is_tf_rnn_cell = _cell.IsTFRnnCell;
if (constants is not null)
{
if (!_cell.SupportOptionalArgs)
{
throw new ValueError(
$"RNN cell {_cell} does not support constants. " +
$"Received: constants={constants}");
}
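// Illustrative note (not from the original source): the step function below assumes the
// incoming `states` is laid out as [state..., constant...]; with e.g. _num_constants = 2,
// [h, c, k1, k2] is split into states = [h, c] and constants = [k1, k2].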
step = (inputs, states) =>
{
constants = new Tensors(states.TakeLast(_num_constants));
states = new Tensors(states.SkipLast(_num_constants));
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states.Single);
};
}
else
{
step = (inputs, states) =>
{
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
var (output, new_states) = _cell.Apply(inputs, states);
return (output, new_states);
};
}
var (last_output, outputs, states) = keras.backend.rnn(
step,
inputs,
initial_state,
constants: constants,
go_backwards: _args.GoBackwards,
mask: mask,
unroll: _args.Unroll,
input_length: row_length != null ? new Tensor(row_length.Value) : new Tensor(timesteps),
time_major: _args.TimeMajor,
zero_output_for_mask: _args.ZeroOutputForMask,
return_all_outputs: _args.ReturnSequences);
if (_args.Stateful)
{
throw new NotImplementedException("The `stateful` argument has not been implemented yet.");
}
Tensors output = new Tensors();
if (_args.ReturnSequences)
{
// TODO(Rinne): add go_backwards parameter and revise the `row_length` param
output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
}
else
{
output = last_output;
}
if (_args.ReturnState)
{
foreach (var state in states)
{
output.Add(state);
}
return output;
}
else
{
return output;
}
}
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if (optional_args is not null && rnn_optional_args is null)
{
throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
}
Tensors? constants = rnn_optional_args?.Constants;
(inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);
if (initial_states is null && constants is null)
{
return base.Apply(inputs);
}
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
{
if (inputs.Length > 1)
{
if (_num_constants == 0)
{
initial_state = new Tensors(inputs.Skip(1));
}
else
{
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
constants = new Tensors(inputs.TakeLast(_num_constants));
}
if (len(initial_state) == 0)
initial_state = null;
inputs = inputs[0];
}
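// Illustrative layout (hypothetical): with _num_constants = 1, an `inputs` of
// [x, h0, c0, k] is split into inputs = x, initial_state = [h0, c0], constants = [k].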
if (_args.Stateful)
{
if (initial_state != null)
{
var tmp = new List<Tensor>();
foreach (var s in nest.flatten(States))
{
tmp.Add(tf.math.count_nonzero(s.Single()));
}
var non_zero_count = tf.add_n(tmp.ToArray());
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
if ((int)non_zero_count.numpy() > 0)
{
initial_state = States;
}
}
else
{
initial_state = States;
}
// TODO(Wanglongzhi2001):
// initial_state = tf.nest.map_structure(
//     # When the layer has an inferred dtype, use the dtype from the cell.
//     lambda v: tf.cast(v, self.compute_dtype or self.cell.compute_dtype),
//     initial_state,
// )
}
else if (initial_state is null)
{
initial_state = get_initial_state(inputs);
}
if (initial_state.Length != States.Length)
{
throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
$"but it received {initial_state.Length} " +
$"initial state(s). Input received: {inputs}");
}
return (inputs, initial_state, constants);
}
private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
if (!is_ragged_input)
{
return;
}
if (_args.Unroll)
{
throw new ValueError("The input received contains RaggedTensors and does " +
"not support unrolling. Disable unrolling by passing " +
"`unroll=False` in the RNN Layer constructor.");
}
if (mask != null)
{
throw new ValueError($"The mask that was passed in was {mask}, which " +
"cannot be applied to RaggedTensor inputs. Please " +
"make sure that there is no mask injected by upstream " +
"layers.");
}
}
void _maybe_reset_cell_dropout_mask(ILayer cell)
{
if (cell is DropoutRNNCellMixin CellDRCMixin)
{
CellDRCMixin.reset_dropout_mask();
CellDRCMixin.reset_recurrent_dropout_mask();
}
}
private static RNNArgs PreConstruct(RNNArgs args)
{
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}
// If true, the output for masked timestep will be zeros, whereas in the
// false case, output from previous timestep is returned for masked timestep.
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
Shape input_shape;
var propIS = (Shape)args.Kwargs.Get("input_shape", null);
var propID = (int?)args.Kwargs.Get("input_dim", null);
var propIL = (int?)args.Kwargs.Get("input_length", null);
if (propIS == null && (propID != null || propIL != null))
{
input_shape = new Shape(
propIL ?? -1,
propID ?? -1);
args.Kwargs["input_shape"] = input_shape;
}
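// For example (illustrative): input_dim = 32 with input_length = 10 produces
// input_shape = (10, 32); a missing value is recorded as -1 (unknown dimension).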
return args;
}
public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
{
throw new NotImplementedException();
}
// It seems that the cell cannot be passed as an interface type.
//public RNN New(IRnnArgCell cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });
//public RNN New(List cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });
protected Tensors get_initial_state(Tensors inputs)
{
var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");
var input = inputs[0];
var input_shape = inputs.shape;
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;
Tensors init_state = new Tensors();
if(get_initial_state_fn != null)
{
init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
}
//if (_cell is RnnCellBase rnn_base_cell)
//{
// init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
//}
else
{
init_state = RnnUtils.generate_zero_filled_state(tf.convert_to_tensor(batch_size), _cell.StateSize, dtype);
}
return init_state;
}
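// Illustrative sketch (hypothetical values): for a cell whose state shape is (4,) and a batch of 2,
// the zero-filled fallback above produces state tensors of shape (2, 4) with the input's dtype.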
// Check whether the state_size contains multiple states.
public static bool is_multiple_state(GeneralizedTensorShape state_size)
{
return state_size.Shapes.Length > 1;
}
}
}