|
|
@@ -11,6 +11,7 @@ using Tensorflow.Common.Extensions; |
|
|
|
using System.Linq.Expressions; |
|
|
|
using Tensorflow.Keras.Utils; |
|
|
|
using Tensorflow.Common.Types; |
|
|
|
using System.Runtime.CompilerServices; |
|
|
|
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Layers.Rnn |
|
|
@@ -30,7 +31,19 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
private int _num_constants; |
|
|
|
protected IVariableV1 _kernel; |
|
|
|
protected IVariableV1 _bias; |
|
|
|
protected IRnnCell _cell; |
|
|
|
private IRnnCell _cell; |
|
|
|
protected IRnnCell Cell |
|
|
|
{ |
|
|
|
get |
|
|
|
{ |
|
|
|
return _cell; |
|
|
|
} |
|
|
|
init |
|
|
|
{ |
|
|
|
_cell = value; |
|
|
|
_self_tracked_trackables.Add(_cell); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public RNN(RNNArgs args) : base(PreConstruct(args)) |
|
|
|
{ |
|
|
@@ -40,14 +53,14 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
// if is StackedRnncell |
|
|
|
if (args.Cells != null) |
|
|
|
{ |
|
|
|
_cell = new StackedRNNCells(new StackedRNNCellsArgs |
|
|
|
Cell = new StackedRNNCells(new StackedRNNCellsArgs |
|
|
|
{ |
|
|
|
Cells = args.Cells |
|
|
|
}); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
_cell = args.Cell; |
|
|
|
Cell = args.Cell; |
|
|
|
} |
|
|
|
|
|
|
|
// get input_shape |
|
|
@@ -65,7 +78,7 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
if (_states == null) |
|
|
|
{ |
|
|
|
// CHECK(Rinne): check if this is correct. |
|
|
|
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null); |
|
|
|
var nested = Cell.StateSize.MapStructure<Tensor?>(x => null); |
|
|
|
_states = nested.AsNest().ToTensors(); |
|
|
|
} |
|
|
|
return _states; |
|
|
@@ -83,7 +96,7 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
} |
|
|
|
|
|
|
|
// state_size is a array of ints or a positive integer |
|
|
|
var state_size = _cell.StateSize.ToSingleShape(); |
|
|
|
var state_size = Cell.StateSize.ToSingleShape(); |
|
|
|
|
|
|
|
// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor |
|
|
|
Func<Shape, Shape> _get_output_shape; |
|
|
@@ -110,12 +123,12 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
return output_shape; |
|
|
|
}; |
|
|
|
|
|
|
|
Type type = _cell.GetType(); |
|
|
|
Type type = Cell.GetType(); |
|
|
|
PropertyInfo output_size_info = type.GetProperty("output_size"); |
|
|
|
Shape output_shape; |
|
|
|
if (output_size_info != null) |
|
|
|
{ |
|
|
|
output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape()); |
|
|
|
output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape()); |
|
|
|
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 |
|
|
|
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); |
|
|
|
} |
|
|
@@ -171,7 +184,9 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
|
|
|
|
public override void build(KerasShapesWrapper input_shape) |
|
|
|
{ |
|
|
|
object get_input_spec(Shape shape) |
|
|
|
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); |
|
|
|
|
|
|
|
InputSpec get_input_spec(Shape shape) |
|
|
|
{ |
|
|
|
var input_spec_shape = shape.as_int_list(); |
|
|
|
|
|
|
@@ -213,10 +228,13 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
// numpy inputs. |
|
|
|
|
|
|
|
|
|
|
|
if (!_cell.Built) |
|
|
|
if (Cell is Layer layer && !layer.Built) |
|
|
|
{ |
|
|
|
_cell.build(input_shape); |
|
|
|
layer.build(input_shape); |
|
|
|
layer.Built = true; |
|
|
|
} |
|
|
|
|
|
|
|
this.built = true; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
@@ -247,10 +265,10 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
|
|
|
|
(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); |
|
|
|
|
|
|
|
_maybe_reset_cell_dropout_mask(_cell); |
|
|
|
if (_cell is StackedRNNCells) |
|
|
|
_maybe_reset_cell_dropout_mask(Cell); |
|
|
|
if (Cell is StackedRNNCells) |
|
|
|
{ |
|
|
|
var stack_cell = _cell as StackedRNNCells; |
|
|
|
var stack_cell = Cell as StackedRNNCells; |
|
|
|
foreach (IRnnCell cell in stack_cell.Cells) |
|
|
|
{ |
|
|
|
_maybe_reset_cell_dropout_mask(cell); |
|
|
@@ -300,10 +318,10 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
bool is_tf_rnn_cell = false; |
|
|
|
if (constants is not null) |
|
|
|
{ |
|
|
|
if (!_cell.SupportOptionalArgs) |
|
|
|
if (!Cell.SupportOptionalArgs) |
|
|
|
{ |
|
|
|
throw new ValueError( |
|
|
|
$"RNN cell {_cell} does not support constants." + |
|
|
|
$"RNN cell {Cell} does not support constants." + |
|
|
|
$"Received: constants={constants}"); |
|
|
|
} |
|
|
|
|
|
|
@@ -312,7 +330,7 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
constants = new Tensors(states.TakeLast(_num_constants).ToArray()); |
|
|
|
states = new Tensors(states.SkipLast(_num_constants).ToArray()); |
|
|
|
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; |
|
|
|
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); |
|
|
|
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); |
|
|
|
return (output, new_states.Single); |
|
|
|
}; |
|
|
|
} |
|
|
@@ -321,7 +339,7 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
step = (inputs, states) => |
|
|
|
{ |
|
|
|
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states; |
|
|
|
var (output, new_states) = _cell.Apply(inputs, states); |
|
|
|
var (output, new_states) = Cell.Apply(inputs, states); |
|
|
|
return (output, new_states); |
|
|
|
}; |
|
|
|
} |
|
|
@@ -562,7 +580,7 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; |
|
|
|
var dtype = input.dtype; |
|
|
|
|
|
|
|
Tensors init_state = _cell.GetInitialState(null, batch_size, dtype); |
|
|
|
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); |
|
|
|
|
|
|
|
return init_state; |
|
|
|
} |
|
|
|