|
@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers |
|
|
/// </summary> |
|
|
/// </summary> |
|
|
public class Bidirectional: Wrapper |
|
|
public class Bidirectional: Wrapper |
|
|
{ |
|
|
{ |
|
|
BidirectionalArgs _args; |
|
|
|
|
|
RNN _forward_layer; |
|
|
|
|
|
RNN _backward_layer; |
|
|
|
|
|
RNN _layer; |
|
|
|
|
|
bool _support_masking = true; |
|
|
|
|
|
int _num_constants = 0; |
|
|
int _num_constants = 0; |
|
|
|
|
|
bool _support_masking = true; |
|
|
bool _return_state; |
|
|
bool _return_state; |
|
|
bool _stateful; |
|
|
bool _stateful; |
|
|
bool _return_sequences; |
|
|
bool _return_sequences; |
|
|
InputSpec _input_spec; |
|
|
|
|
|
|
|
|
BidirectionalArgs _args; |
|
|
RNNArgs _layer_args_copy; |
|
|
RNNArgs _layer_args_copy; |
|
|
|
|
|
RNN _forward_layer; |
|
|
|
|
|
RNN _backward_layer; |
|
|
|
|
|
RNN _layer; |
|
|
|
|
|
InputSpec _input_spec; |
|
|
public Bidirectional(BidirectionalArgs args):base(args) |
|
|
public Bidirectional(BidirectionalArgs args):base(args) |
|
|
{ |
|
|
{ |
|
|
_args = args; |
|
|
_args = args; |
|
@@ -66,12 +66,16 @@ namespace Tensorflow.Keras.Layers |
|
|
|
|
|
|
|
|
// Recreate the forward layer from the original layer config, so that it |
|
|
// Recreate the forward layer from the original layer config, so that it |
|
|
// will not carry over any state from the layer. |
|
|
// will not carry over any state from the layer. |
|
|
var actualType = _layer.GetType(); |
|
|
|
|
|
if (actualType == typeof(LSTM)) |
|
|
|
|
|
|
|
|
if (_layer is LSTM) |
|
|
{ |
|
|
{ |
|
|
var arg = _layer_args_copy as LSTMArgs; |
|
|
var arg = _layer_args_copy as LSTMArgs; |
|
|
_forward_layer = new LSTM(arg); |
|
|
_forward_layer = new LSTM(arg); |
|
|
} |
|
|
} |
|
|
|
|
|
else if(_layer is SimpleRNN) |
|
|
|
|
|
{ |
|
|
|
|
|
var arg = _layer_args_copy as SimpleRNNArgs; |
|
|
|
|
|
_forward_layer = new SimpleRNN(arg); |
|
|
|
|
|
} |
|
|
// TODO(Wanglongzhi2001), add GRU if case. |
|
|
// TODO(Wanglongzhi2001), add GRU if case. |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
@@ -154,12 +158,18 @@ namespace Tensorflow.Keras.Layers |
|
|
{ |
|
|
{ |
|
|
config.GoBackwards = !config.GoBackwards; |
|
|
config.GoBackwards = !config.GoBackwards; |
|
|
} |
|
|
} |
|
|
var actualType = layer.GetType(); |
|
|
|
|
|
if (actualType == typeof(LSTM)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (layer is LSTM) |
|
|
{ |
|
|
{ |
|
|
var arg = config as LSTMArgs; |
|
|
var arg = config as LSTMArgs; |
|
|
return new LSTM(arg); |
|
|
return new LSTM(arg); |
|
|
} |
|
|
} |
|
|
|
|
|
else if(layer is SimpleRNN) |
|
|
|
|
|
{ |
|
|
|
|
|
var arg = config as SimpleRNNArgs; |
|
|
|
|
|
return new SimpleRNN(arg); |
|
|
|
|
|
} |
|
|
|
|
|
// TODO(Wanglongzhi2001), add GRU if case. |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
return new RNN(cell, config); |
|
|
return new RNN(cell, config); |
|
@@ -183,7 +193,6 @@ namespace Tensorflow.Keras.Layers |
|
|
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) |
|
|
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) |
|
|
{ |
|
|
{ |
|
|
// `Bidirectional.call` implements the same API as the wrapped `RNN`. |
|
|
// `Bidirectional.call` implements the same API as the wrapped `RNN`. |
|
|
|
|
|
|
|
|
Tensors forward_inputs; |
|
|
Tensors forward_inputs; |
|
|
Tensors backward_inputs; |
|
|
Tensors backward_inputs; |
|
|
Tensors forward_state; |
|
|
Tensors forward_state; |
|
|