
feat: add Bidirectional layer

Tag: v0.110.4-Transformer-Model
Wanglongzhi2001 committed 2 years ago
Parent commit: 0c9437afcb
11 changed files with 428 additions and 18 deletions
1. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs (+20 -0)
2. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs (+5 -0)
3. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs (+5 -0)
4. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs (+24 -0)
5. src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+13 -1)
6. src/TensorFlowNET.Keras/Layers/LayersApi.cs (+14 -0)
7. src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs (+33 -0)
8. src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs (+276 -0)
9. src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs (+19 -12)
10. src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+8 -3)
11. test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+11 -2)

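For orientation, a minimal usage sketch of the new wrapper, mirroring the unit test added in this commit (shapes taken from that test):

var inputs = tf.random.normal((32, 10, 8));   // (batch, time, features)
var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences: true));
var outputs = bi.Apply(inputs);               // (32, 10, 20): forward and backward outputs concatenated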
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs (+20 -0)

@@ -0,0 +1,20 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
    public class BidirectionalArgs : AutoSerializeLayerArgs
    {
        [JsonProperty("layer")]
        public ILayer Layer { get; set; }
        [JsonProperty("merge_mode")]
        public string? MergeMode { get; set; }
        [JsonProperty("backward_layer")]
        public ILayer BackwardLayer { get; set; }
        public NDArray Weights { get; set; }
    }
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs (+5 -0)

@@ -5,5 +5,10 @@
        // TODO: maybe change the `RNNArgs` and implement this class.
        public bool UnitForgetBias { get; set; }
        public int Implementation { get; set; }

        public LSTMArgs Clone()
        {
            return (LSTMArgs)MemberwiseClone();
        }
    }
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs (+5 -0)

@@ -40,5 +40,10 @@ namespace Tensorflow.Keras.ArgsDefinition
        public bool ZeroOutputForMask { get; set; } = false;
        [JsonProperty("recurrent_dropout")]
        public float RecurrentDropout { get; set; } = .0f;

        public RNNArgs Clone()
        {
            return (RNNArgs)MemberwiseClone();
        }
    }
}
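Note: both Clone helpers are shallow (MemberwiseClone), so reference-typed members such as the activation delegate are shared between the original args and the copy. That is enough for the Bidirectional wrapper below, which only flips value-typed flags on the copy. A small illustration (hypothetical variable name):

var copy = lstmArgs.Clone();              // new LSTMArgs instance
copy.GoBackwards = !copy.GoBackwards;     // safe: value-typed flag, original args unchanged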

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs (+24 -0)

@@ -0,0 +1,24 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    public class WrapperArgs : AutoSerializeLayerArgs
    {
        [JsonProperty("layer")]
        public ILayer Layer { get; set; }

        public WrapperArgs(ILayer layer)
        {
            Layer = layer;
        }

        public static implicit operator WrapperArgs(BidirectionalArgs args)
            => new WrapperArgs(args.Layer);
    }
}
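The implicit conversion above lets Bidirectional pass its own args object straight to the Wrapper base constructor without repackaging; only Layer is carried over, while the merge mode, weights, and backward layer remain on the BidirectionalArgs instance. The call path (see Bidirectional.cs below):

public Bidirectional(BidirectionalArgs args) : base(args)   // implicit BidirectionalArgs -> WrapperArgs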

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+13 -1)

@@ -258,7 +258,19 @@ namespace Tensorflow.Keras.Layers
            float dropout = 0f,
            float recurrent_dropout = 0f,
            bool reset_after = true);

        /// <summary>
        /// Bidirectional wrapper for RNNs.
        /// </summary>
        /// <param name="layer">A `keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`.</param>
        /// <param name="merge_mode">Mode by which the outputs of the forward and backward RNNs are combined: one of "sum", "mul", "concat", "ave" or null. If null, the outputs are not combined and are returned as a list.</param>
        /// <param name="weights">Initial weights to load into the wrapped layers.</param>
        /// <param name="backward_layer">Optional `keras.layers.RNN` instance to handle the backwards input processing. If not provided, it is created from `layer` automatically.</param>
        /// <returns></returns>
        public ILayer Bidirectional(
            ILayer layer,
            string merge_mode = "concat",
            NDArray weights = null,
            ILayer backward_layer = null);

        public ILayer Subtract();
    }
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs (+14 -0)

@@ -908,6 +908,20 @@ namespace Tensorflow.Keras.Layers
                ResetAfter = reset_after
            });

        public ILayer Bidirectional(
            ILayer layer,
            string merge_mode = "concat",
            NDArray weights = null,
            ILayer backward_layer = null)
            => new Bidirectional(new BidirectionalArgs
            {
                Layer = layer,
                MergeMode = merge_mode,
                Weights = weights,
                BackwardLayer = backward_layer
            });

        /// <summary>
        ///
        /// </summary>

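A hedged sketch of the optional parameters (it assumes the LSTM factory exposes a go_backwards argument, as the Keras API does; note that _verify_layer_config in Bidirectional.cs below requires an explicit backward layer to differ from the forward one in go_backwards):

var forward = keras.layers.LSTM(10, return_sequences: true);
var backward = keras.layers.LSTM(10, return_sequences: true, go_backwards: true);
// "ave" averages the forward and backward outputs instead of concatenating them
var bi = tf.keras.layers.Bidirectional(forward, merge_mode: "ave", backward_layer: backward);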
src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs (+33 -0)

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
    /// Do not use this class as a layer; it is only an abstract base class.
    /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
    /// </summary>
    public abstract class Wrapper : Layer
    {
        public ILayer _layer;

        public Wrapper(WrapperArgs args) : base(args)
        {
            _layer = args.Layer;
        }

        public virtual void Build(KerasShapesWrapper input_shape)
        {
            if (!_layer.Built)
            {
                _layer.build(input_shape);
            }
            built = true;
        }
    }
}

src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs (+276 -0)

@@ -0,0 +1,276 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Bidirectional wrapper for RNNs.
    /// </summary>
    public class Bidirectional : Wrapper
    {
        BidirectionalArgs _args;
        RNN _forward_layer;
        RNN _backward_layer;
        RNN _layer;
        bool _support_masking = true;
        int _num_constants = 0;
        bool _return_state;
        bool _stateful;
        bool _return_sequences;
        InputSpec _input_spec;
        RNNArgs _layer_args_copy;

        public Bidirectional(BidirectionalArgs args) : base(args)
        {
            _args = args;
            if (_args.Layer is not ILayer)
                throw new ValueError(
                    "Please initialize `Bidirectional` layer with a " +
                    $"`tf.keras.layers.Layer` instance. Received: {_args.Layer}");

            if (_args.BackwardLayer is not null && _args.BackwardLayer is not ILayer)
                throw new ValueError(
                    "`backward_layer` needs to be a `tf.keras.layers.Layer` " +
                    $"instance. Received: {_args.BackwardLayer}");
            if (!new List<string> { "sum", "mul", "ave", "concat", null }.Contains(_args.MergeMode))
            {
                throw new ValueError(
                    $"Invalid merge mode. Received: {_args.MergeMode}. " +
                    "Merge mode should be one of " +
                    "{\"sum\", \"mul\", \"ave\", \"concat\", null}"
                );
            }
            if (_args.Layer is RNN)
            {
                _layer = _args.Layer as RNN;
            }
            else
            {
                throw new ValueError(
                    "Bidirectional only supports RNN instances such as LSTM or GRU");
            }
            _return_state = _layer.Args.ReturnState;
            _return_sequences = _layer.Args.ReturnSequences;
            _stateful = _layer.Args.Stateful;
            _layer_args_copy = _layer.Args.Clone();
            // We don't want to track `layer` since we're already tracking the two
            // copies of it we actually run.
            // TODO(Wanglongzhi2001): the setattr_tracking feature has not been implemented yet.
            // _setattr_tracking = false;
            // super().__init__(layer, **kwargs)
            // _setattr_tracking = true;

            // Recreate the forward layer from the original layer config, so that it
            // will not carry over any state from the layer.
            var actualType = _layer.GetType();
            if (actualType == typeof(LSTM))
            {
                var arg = _layer_args_copy as LSTMArgs;
                _forward_layer = new LSTM(arg);
            }
            // TODO(Wanglongzhi2001): add a GRU case.
            else
            {
                _forward_layer = new RNN(_layer.Cell, _layer_args_copy);
            }
            //_forward_layer = _recreate_layer_from_config(_layer);
            if (_args.BackwardLayer is null)
            {
                _backward_layer = _recreate_layer_from_config(_layer, go_backwards: true);
            }
            else
            {
                _backward_layer = _args.BackwardLayer as RNN;
            }
            _forward_layer.Name = "forward_" + _forward_layer.Name;
            _backward_layer.Name = "backward_" + _backward_layer.Name;
            _verify_layer_config();

            void force_zero_output_for_mask(RNN layer)
            {
                layer.Args.ZeroOutputForMask = layer.Args.ReturnSequences;
            }

            force_zero_output_for_mask(_forward_layer);
            force_zero_output_for_mask(_backward_layer);

            if (_args.Weights is not null)
            {
                // Split the flat weights in half: the first half goes to the
                // forward layer, the second half to the backward layer.
                var nw = len(_args.Weights);
                _forward_layer.set_weights(_args.Weights[$":{nw / 2}"]);
                _backward_layer.set_weights(_args.Weights[$"{nw / 2}:"]);
            }

            _input_spec = _layer.InputSpec;
        }

        private void _verify_layer_config()
        {
            if (_forward_layer.Args.GoBackwards == _backward_layer.Args.GoBackwards)
            {
                throw new ValueError(
                    "Forward layer and backward layer should have different " +
                    "`go_backwards` value. " +
                    "forward_layer.go_backwards = " +
                    $"{_forward_layer.Args.GoBackwards}, " +
                    "backward_layer.go_backwards = " +
                    $"{_backward_layer.Args.GoBackwards}");
            }
            if (_forward_layer.Args.Stateful != _backward_layer.Args.Stateful)
            {
                throw new ValueError(
                    "Forward layer and backward layer are expected to have " +
                    "the same value for attribute `stateful`, got " +
                    $"{_forward_layer.Args.Stateful} for forward layer and " +
                    $"{_backward_layer.Args.Stateful} for backward layer");
            }
            if (_forward_layer.Args.ReturnState != _backward_layer.Args.ReturnState)
            {
                throw new ValueError(
                    "Forward layer and backward layer are expected to have " +
                    "the same value for attribute `return_state`, got " +
                    $"{_forward_layer.Args.ReturnState} for forward layer and " +
                    $"{_backward_layer.Args.ReturnState} for backward layer");
            }
            if (_forward_layer.Args.ReturnSequences != _backward_layer.Args.ReturnSequences)
            {
                throw new ValueError(
                    "Forward layer and backward layer are expected to have " +
                    "the same value for attribute `return_sequences`, got " +
                    $"{_forward_layer.Args.ReturnSequences} for forward layer and " +
                    $"{_backward_layer.Args.ReturnSequences} for backward layer");
            }
        }

        private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
        {
            // get_config() returns the layer's live args instance, so clone it
            // before mutating `GoBackwards`; otherwise the wrapped layer's own
            // args would be flipped as a side effect.
            var config = (layer.get_config() as RNNArgs).Clone();
            var cell = layer.Cell;
            if (go_backwards)
            {
                config.GoBackwards = !config.GoBackwards;
            }
            var actualType = layer.GetType();
            if (actualType == typeof(LSTM))
            {
                var arg = config as LSTMArgs;
                return new LSTM(arg);
            }
            else
            {
                return new RNN(cell, config);
            }
        }

        public override void build(KerasShapesWrapper input_shape)
        {
            _buildInputShape = input_shape;
            tf_with(ops.name_scope(_forward_layer.Name), scope =>
            {
                _forward_layer.build(input_shape);
            });
            tf_with(ops.name_scope(_backward_layer.Name), scope =>
            {
                _backward_layer.build(input_shape);
            });
            built = true;
        }

        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            // `Bidirectional.Call` implements the same API as the wrapped `RNN`.

            Tensors forward_inputs;
            Tensors backward_inputs;
            Tensors forward_state;
            Tensors backward_state;
            // Python reference: if isinstance(inputs, list) and len(inputs) > 1:
            if (inputs.Length > 1)
            {
                // initial_states are keras tensors, which means they are passed
                // in together with the inputs as a list. The initial_states need
                // to be split into forward and backward sections, and fed to the
                // layers accordingly.
                forward_inputs = new Tensors { inputs[0] };
                backward_inputs = new Tensors { inputs[0] };
                var pivot = (len(inputs) - _num_constants) / 2 + 1;
                // add forward initial state
                forward_inputs.Concat(new Tensors { inputs[$"1:{pivot}"] });
                if (_num_constants == 0)
                {
                    // add backward initial state
                    backward_inputs.Concat(new Tensors { inputs[$"{pivot}:"] });
                }
                else
                {
                    // add backward initial state
                    backward_inputs.Concat(new Tensors { inputs[$"{pivot}:{-_num_constants}"] });
                    // add constants for forward and backward layers
                    forward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] });
                    backward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] });
                }
                forward_state = null;
                backward_state = null;
            }
            else if (state is not null)
            {
                // initial_states are not keras tensors, e.g. an eager tensor from
                // a numpy array. They are only passed in from the kwarg
                // initial_state, and should be passed to the forward/backward
                // layers via the kwarg initial_state as well.
                forward_inputs = inputs;
                backward_inputs = inputs;
                var half = len(state) / 2;
                forward_state = state[$":{half}"];
                backward_state = state[$"{half}:"];
            }
            else
            {
                forward_inputs = inputs;
                backward_inputs = inputs;
                forward_state = null;
                backward_state = null;
            }
            var y = _forward_layer.Apply(forward_inputs, forward_state);
            var y_rev = _backward_layer.Apply(backward_inputs, backward_state);

            Tensors states = new();
            if (_return_state)
            {
                states = y["1:"] + y_rev["1:"];
                y = y[0];
                y_rev = y_rev[0];
            }

            if (_return_sequences)
            {
                int time_dim = _forward_layer.Args.TimeMajor ? 0 : 1;
                y_rev = keras.backend.reverse(y_rev, time_dim);
            }
            Tensors output;
            if (_args.MergeMode == "concat")
                output = keras.backend.concatenate(new Tensors { y.Single(), y_rev.Single() });
            else if (_args.MergeMode == "sum")
                output = y.Single() + y_rev.Single();
            else if (_args.MergeMode == "ave")
                output = (y.Single() + y_rev.Single()) / 2;
            else if (_args.MergeMode == "mul")
                output = y.Single() * y_rev.Single();
            else if (_args.MergeMode is null)
                output = new Tensors { y.Single(), y_rev.Single() };
            else
                throw new ValueError(
                    "Unrecognized value for `merge_mode`. " +
                    $"Received: {_args.MergeMode}. " +
                    "Expected values are [\"concat\", \"sum\", \"ave\", \"mul\", null]");
            if (_return_state)
            {
                if (_args.MergeMode is null)
                    // merge_mode == null: both direction outputs followed by all states
                    return output.Concat(states).ToArray().ToTensors();
                // otherwise: the merged output followed by all forward and backward states
                return new Tensor[] { output.Single() }.Concat(states).ToArray().ToTensors();
            }
            return output;
        }
    }
}
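For the return_state path at the end of Call, a hedged usage sketch (it assumes the LSTM factory exposes return_state and the usual Keras per-direction state layout [h, c]):

var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(4, return_state: true));
var result = bi.Apply(tf.random.normal((32, 10, 8)));
// with a non-null merge_mode: [merged_output, h_forward, c_forward, h_backward, c_backward]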

src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs (+19 -12)

@@ -3,6 +3,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -14,15 +15,15 @@ namespace Tensorflow.Keras.Layers
    /// </summary>
    public class LSTM : RNN
    {
        LSTMArgs args;
        LSTMArgs _args;
        InputSpec[] _state_spec;
        InputSpec _input_spec;
        bool _could_use_gpu_kernel;
        public LSTMArgs Args { get => _args; }
        public LSTM(LSTMArgs args) :
            base(CreateCell(args), args)
        {
            this.args = args;
            _args = args;
            _input_spec = new InputSpec(ndim: 3);
            _state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray();
            _could_use_gpu_kernel = args.Activation == keras.activations.Tanh
@@ -71,7 +72,7 @@ namespace Tensorflow.Keras.Layers

            var single_input = inputs.Single;
            var input_shape = single_input.shape;
            var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];
            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

            _maybe_reset_cell_dropout_mask(Cell);

@@ -87,26 +88,26 @@ namespace Tensorflow.Keras.Layers
                inputs,
                initial_state,
                constants: null,
                go_backwards: args.GoBackwards,
                go_backwards: _args.GoBackwards,
                mask: mask,
                unroll: args.Unroll,
                unroll: _args.Unroll,
                input_length: ops.convert_to_tensor(timesteps),
                time_major: args.TimeMajor,
                zero_output_for_mask: args.ZeroOutputForMask,
                return_all_outputs: args.ReturnSequences
                time_major: _args.TimeMajor,
                zero_output_for_mask: _args.ZeroOutputForMask,
                return_all_outputs: _args.ReturnSequences
            );

            Tensor output;
            if (args.ReturnSequences)
            if (_args.ReturnSequences)
            {
                output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards);
                output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, _args.GoBackwards);
            }
            else
            {
                output = last_output;
            }

            if (args.ReturnState)
            if (_args.ReturnState)
            {
                return new Tensor[] { output }.Concat(states).ToArray().ToTensors();
            }
@@ -115,5 +116,11 @@ namespace Tensorflow.Keras.Layers
                return output;
            }
        }

        public override IKerasConfig get_config()
        {
            return _args;
        }

    }
}

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+8 -3)

@@ -31,7 +31,9 @@ namespace Tensorflow.Keras.Layers
        protected IVariableV1 _kernel;
        protected IVariableV1 _bias;
        private IRnnCell _cell;
        protected IRnnCell Cell

        public RNNArgs Args { get => _args; }
        public IRnnCell Cell
        {
            get
            {
@@ -570,10 +572,13 @@ namespace Tensorflow.Keras.Layers
            var input_shape = array_ops.shape(inputs);
            var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
            var dtype = input.dtype;

            Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);

            return init_state;
        }

        public override IKerasConfig get_config()
        {
            return _args;
        }
    }
}
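Since get_config() returns the live _args instance rather than a snapshot, callers that mutate the result should clone it first, as Bidirectional._recreate_layer_from_config above does:

var config = (rnn.get_config() as RNNArgs).Clone();   // hypothetical local `rnn`
config.GoBackwards = !config.GoBackwards;             // original layer unaffected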

test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+11 -2)

@@ -5,6 +5,7 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
@@ -38,8 +39,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
            var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
            var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
            var (output, state) = stackedRNNCell.Apply(inputs, states);
            Console.WriteLine(output);
            Console.WriteLine(state.shape);
            Assert.AreEqual((32, 5), output.shape);
            Assert.AreEqual((32, 4), state[0].shape);
        }
@@ -108,6 +107,7 @@ namespace Tensorflow.Keras.UnitTest.Layers
            var inputs = tf.random.normal((32, 10, 8));
            var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
            var rnn = tf.keras.layers.RNN(cell: cell);
            var cfg = rnn.get_config();
            var output = rnn.Apply(inputs);
            Assert.AreEqual((32, 10), output.shape);

@@ -145,5 +145,14 @@ namespace Tensorflow.Keras.UnitTest.Layers
            Assert.AreEqual((32, 4), output.shape);
        }

        [TestMethod]
        public void Bidirectional()
        {
            var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences: true));
            var inputs = tf.random.normal((32, 10, 8));
            var outputs = bi.Apply(inputs);
            Assert.AreEqual((32, 10, 20), outputs.shape);
        }
    }
}
