Browse Source

Merge pull request #1162 from Wanglongzhi2001/master

fix: remove the reflection in the implementation of Bidirectional
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
adeed05f64
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 11 deletions
  1. +20
    -11
      src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs

+ 20
- 11
src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs View File

@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
/// </summary>
public class Bidirectional: Wrapper
{
BidirectionalArgs _args;
RNN _forward_layer;
RNN _backward_layer;
RNN _layer;
bool _support_masking = true;
int _num_constants = 0;
bool _support_masking = true;
bool _return_state;
bool _stateful;
bool _return_sequences;
InputSpec _input_spec;
BidirectionalArgs _args;
RNNArgs _layer_args_copy;
RNN _forward_layer;
RNN _backward_layer;
RNN _layer;
InputSpec _input_spec;
public Bidirectional(BidirectionalArgs args):base(args)
{
_args = args;
@@ -66,12 +66,16 @@ namespace Tensorflow.Keras.Layers

// Recreate the forward layer from the original layer config, so that it
// will not carry over any state from the layer.
var actualType = _layer.GetType();
if (actualType == typeof(LSTM))
if (_layer is LSTM)
{
var arg = _layer_args_copy as LSTMArgs;
_forward_layer = new LSTM(arg);
}
else if(_layer is SimpleRNN)
{
var arg = _layer_args_copy as SimpleRNNArgs;
_forward_layer = new SimpleRNN(arg);
}
// TODO(Wanglongzhi2001), add GRU if case.
else
{
@@ -154,12 +158,18 @@ namespace Tensorflow.Keras.Layers
{
config.GoBackwards = !config.GoBackwards;
}
var actualType = layer.GetType();
if (actualType == typeof(LSTM))
if (layer is LSTM)
{
var arg = config as LSTMArgs;
return new LSTM(arg);
}
else if(layer is SimpleRNN)
{
var arg = config as SimpleRNNArgs;
return new SimpleRNN(arg);
}
// TODO(Wanglongzhi2001), add GRU if case.
else
{
return new RNN(cell, config);
@@ -183,7 +193,6 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
// `Bidirectional.call` implements the same API as the wrapped `RNN`.

Tensors forward_inputs;
Tensors backward_inputs;
Tensors forward_state;


Loading…
Cancel
Save