
Merge pull request #1168 from Wanglongzhi2001/master

feat: implement GRU layer
Tag: v0.110.4-Transformer-Model
Author: Haiping (GitHub), 2 years ago
Parent commit: 70d681c020
7 changed files with 300 additions and 41 deletions
  1. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs (+29, -0)
  2. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs (+13, -0)
  3. src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+19, -0)
  4. src/TensorFlowNET.Keras/Layers/LayersApi.cs (+60, -1)
  5. src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs (+168, -0)
  6. src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+2, -40)
  7. test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+9, -0)

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs (+29, -0)

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    public class GRUArgs : AutoSerializeLayerArgs
    {
        public int Units { get; set; }
        public Activation Activation { get; set; }
        public Activation RecurrentActivation { get; set; }
        public bool UseBias { get; set; } = true;
        public float Dropout { get; set; } = .0f;
        public float RecurrentDropout { get; set; } = .0f;
        public IInitializer KernelInitializer { get; set; }
        public IInitializer RecurrentInitializer { get; set; }
        public IInitializer BiasInitializer { get; set; }
        public bool ReturnSequences { get; set; }
        public bool ReturnState { get; set; }
        public bool GoBackwards { get; set; }
        public bool Stateful { get; set; }
        public bool Unroll { get; set; }
        public bool TimeMajor { get; set; }
        public bool ResetAfter { get; set; }
        public int Implementation { get; set; } = 2;
    }
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs (+13, -0)

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.ArgsDefinition
{
    // Implements IOptionalArgs so an instance can be passed through the
    // optional_args parameter of Layer.Call (see GRU.Call below).
    public class GRUOptionalArgs : IOptionalArgs
    {
        public string Identifier => "GRU";

        public Tensor Mask { get; set; } = null;
    }
}

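The Mask tensor above is what GRU.Call below unpacks. A minimal sketch of how a caller might thread a mask through (not part of the diff; the Apply overload that forwards optional_args, and the shapes, are assumptions):

    // One flag per (batch, timestep); zeros would mark padded steps.
    var mask = tf.ones((32, 10));
    var gru = tf.keras.layers.GRU(4);
    var output = gru.Apply(tf.ones((32, 10, 8)),
        optional_args: new GRUOptionalArgs { Mask = mask });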
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+19, -0)

@@ -259,6 +259,25 @@ namespace Tensorflow.Keras.Layers
            float recurrent_dropout = 0f,
            bool reset_after = true);

        public ILayer GRU(
            int units,
            string activation = "tanh",
            string recurrent_activation = "sigmoid",
            bool use_bias = true,
            string kernel_initializer = "glorot_uniform",
            string recurrent_initializer = "orthogonal",
            string bias_initializer = "zeros",
            float dropout = 0f,
            float recurrent_dropout = 0f,
            bool return_sequences = false,
            bool return_state = false,
            bool go_backwards = false,
            bool stateful = false,
            bool unroll = false,
            bool time_major = false,
            bool reset_after = true
            );

        /// <summary>
        /// Bidirectional wrapper for RNNs.
        /// </summary>


src/TensorFlowNET.Keras/Layers/LayersApi.cs (+60, -1)

@@ -784,7 +784,7 @@ namespace Tensorflow.Keras.Layers
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
@@ -908,6 +908,65 @@ namespace Tensorflow.Keras.Layers
ResetAfter = reset_after
});

        /// <summary>
        /// Gated Recurrent Unit - Cho et al. 2014.
        /// </summary>
        /// <param name="units">Positive integer, dimensionality of the output space.</param>
        /// <param name="activation">Activation function to use. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
        /// <param name="recurrent_activation">Activation function to use for the recurrent step. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
        /// <param name="use_bias">Boolean (default `True`), whether the layer uses a bias vector.</param>
        /// <param name="kernel_initializer">Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.</param>
        /// <param name="recurrent_initializer">Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.</param>
        /// <param name="bias_initializer">Initializer for the bias vector. Default: `zeros`.</param>
        /// <param name="dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.</param>
        /// <param name="recurrent_dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.</param>
        /// <param name="return_sequences">Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.</param>
        /// <param name="return_state">Boolean. Whether to return the last state in addition to the output. Default: `False`.</param>
        /// <param name="go_backwards">Boolean (default `False`). If `True`, process the input sequence backwards and return the reversed sequence.</param>
        /// <param name="stateful">Boolean (default `False`). If `True`, the last state for each sample at index i in a batch will be used as the initial state for the sample of index i in the following batch.</param>
        /// <param name="unroll">Boolean (default `False`). If `True`, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN at the cost of memory, and is only suitable for short sequences.</param>
        /// <param name="time_major">The shape format of the `inputs` and `outputs` tensors. If `True`, they are `[timesteps, batch, feature]`; otherwise `[batch, timesteps, feature]`.</param>
        /// <param name="reset_after">GRU convention (whether to apply the reset gate after or before the matrix multiplication). `False` = "before", `True` = "after" (default and cuDNN compatible).</param>
        /// <returns></returns>
        public ILayer GRU(
            int units,
            string activation = "tanh",
            string recurrent_activation = "sigmoid",
            bool use_bias = true,
            string kernel_initializer = "glorot_uniform",
            string recurrent_initializer = "orthogonal",
            string bias_initializer = "zeros",
            float dropout = 0f,
            float recurrent_dropout = 0f,
            bool return_sequences = false,
            bool return_state = false,
            bool go_backwards = false,
            bool stateful = false,
            bool unroll = false,
            bool time_major = false,
            bool reset_after = true
            )
            => new GRU(new GRUArgs
            {
                Units = units,
                Activation = keras.activations.GetActivationFromName(activation),
                RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
                KernelInitializer = GetInitializerByName(kernel_initializer),
                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
                BiasInitializer = GetInitializerByName(bias_initializer),
                UseBias = use_bias,
                Dropout = dropout,
                RecurrentDropout = recurrent_dropout,
                ReturnSequences = return_sequences,
                ReturnState = return_state,
                GoBackwards = go_backwards,
                Stateful = stateful,
                TimeMajor = time_major,
                Unroll = unroll,
                ResetAfter = reset_after
            });

        public ILayer Bidirectional(
            ILayer layer,
            string merge_mode = "concat",


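A quick usage sketch of the new factory method (illustrative, not part of the diff); output shapes follow the XML docs above:

    var inputs = tf.ones((32, 10, 8));                        // (batch, time, features)
    var last = tf.keras.layers.GRU(4).Apply(inputs);          // (32, 4): last output only
    var seq = tf.keras.layers.GRU(4, return_sequences: true)
                .Apply(inputs);                               // (32, 10, 4): one output per timestep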
src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs (+168, -0)

@@ -0,0 +1,168 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Layers
{
    public class GRU : RNN
    {
        GRUArgs _args;
        private GRUCell _cell;

        bool _return_runtime;
        public GRUCell Cell { get => _cell; }
        public int units { get => _args.Units; }
        public Activation activation { get => _args.Activation; }
        public Activation recurrent_activation { get => _args.RecurrentActivation; }
        public bool use_bias { get => _args.UseBias; }
        public float dropout { get => _args.Dropout; }
        public float recurrent_dropout { get => _args.RecurrentDropout; }
        public IInitializer kernel_initializer { get => _args.KernelInitializer; }
        public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; }
        public IInitializer bias_initializer { get => _args.BiasInitializer; }
        public int implementation { get => _args.Implementation; }
        public bool reset_after { get => _args.ResetAfter; }

        public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args))
        {
            _args = args;

            if (_args.Implementation == 0)
            {
                // Print the warning in red so it stays visible even in release builds.
                Console.ForegroundColor = ConsoleColor.Red;
                Console.WriteLine("Warning: `implementation=0` has been deprecated, " +
                    "and now defaults to `implementation=2`. " +
                    "Please update your layer call.");
                Console.ResetColor();
            }

            GRUCell cell = new GRUCell(new GRUCellArgs
            {
                Units = _args.Units,
                Activation = _args.Activation,
                RecurrentActivation = _args.RecurrentActivation,
                UseBias = _args.UseBias,
                Dropout = _args.Dropout,
                RecurrentDropout = _args.RecurrentDropout,
                KernelInitializer = _args.KernelInitializer,
                RecurrentInitializer = _args.RecurrentInitializer,
                BiasInitializer = _args.BiasInitializer,
                ResetAfter = _args.ResetAfter,
                Implementation = _args.Implementation
            });
            _cell = cell;
        }

        protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs;
            if (optional_args is not null && gru_optional_args is null)
            {
                throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`.");
            }
            Tensors? mask = gru_optional_args?.Mask;

            // Ragged input is not supported yet.
            int row_length = 0;
            bool is_ragged_input = false;

            _validate_args_if_ragged(is_ragged_input, mask);

            // GRU does not support constants; ignore them during processing.
            (inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null);

            if (mask != null && mask.Length > 1)
            {
                mask = mask[0];
            }

            var input_shape = inputs.shape;
            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

            // TODO(Wanglongzhi2001), finish _could_use_gpu_kernel part

            // The step function consumes one timestep of input plus the previous
            // states and returns (output, new_states); keras.backend.rnn iterates
            // it over the time axis.
            Func<Tensors, Tensors, (Tensors, Tensors)> step = (cell_inputs, cell_states) =>
            {
                var res = Cell.Apply(cell_inputs, cell_states, training ?? true);
                var (output, state) = res;
                return (output, state);
            };

            var (last_output, outputs, states) = keras.backend.rnn(
                step,
                inputs,
                initial_state,
                constants: null,
                go_backwards: _args.GoBackwards,
                mask: mask,
                unroll: _args.Unroll,
                input_length: ops.convert_to_tensor(timesteps),
                time_major: _args.TimeMajor,
                zero_output_for_mask: base.Args.ZeroOutputForMask,
                return_all_outputs: _args.ReturnSequences
            );

            // Return the full sequence or only the last output, and optionally
            // append the final state.
            var output = _args.ReturnSequences ? outputs : last_output;

            if (_args.ReturnState)
            {
                output = new Tensors { output, states };
            }
            return output;
        }

        private static IRnnCell CreateCell(GRUArgs gruArgs)
        {
            return new GRUCell(new GRUCellArgs
            {
                Units = gruArgs.Units,
                Activation = gruArgs.Activation,
                RecurrentActivation = gruArgs.RecurrentActivation,
                UseBias = gruArgs.UseBias,
                Dropout = gruArgs.Dropout,
                RecurrentDropout = gruArgs.RecurrentDropout,
                KernelInitializer = gruArgs.KernelInitializer,
                RecurrentInitializer = gruArgs.RecurrentInitializer,
                BiasInitializer = gruArgs.BiasInitializer,
                ResetAfter = gruArgs.ResetAfter,
                Implementation = gruArgs.Implementation
            });
        }

        private static RNNArgs PreConstruct(GRUArgs args)
        {
            return new RNNArgs
            {
                ReturnSequences = args.ReturnSequences,
                ReturnState = args.ReturnState,
                GoBackwards = args.GoBackwards,
                Stateful = args.Stateful,
                Unroll = args.Unroll,
                TimeMajor = args.TimeMajor,
                Units = args.Units,
                Activation = args.Activation,
                RecurrentActivation = args.RecurrentActivation,
                UseBias = args.UseBias,
                Dropout = args.Dropout,
                RecurrentDropout = args.RecurrentDropout,
                KernelInitializer = args.KernelInitializer,
                RecurrentInitializer = args.RecurrentInitializer,
                BiasInitializer = args.BiasInitializer
            };
        }
    }
}

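Because `ReturnState` wraps the result as `Tensors { output, states }`, a caller can index both pieces; a small sketch (illustrative, not part of the diff):

    var gru = tf.keras.layers.GRU(4, return_state: true);
    var result = gru.Apply(tf.ones((32, 10, 8)));
    var output = result[0];   // (32, 4): last output
    var state = result[1];    // (32, 4): final hidden state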
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+2, -40)

@@ -25,8 +25,8 @@ namespace Tensorflow.Keras.Layers
        private RNNArgs _args;
        private object _input_spec = null; // or NoneValue??
        private object _state_spec = null;
-       private Tensors _states = null;
        private object _constants_spec = null;
+       private Tensors _states = null;
        private int _num_constants;
        protected IVariableV1 _kernel;
        protected IVariableV1 _bias;
@@ -469,7 +469,7 @@ namespace Tensorflow.Keras.Layers
            return (inputs, initial_state, constants);
        }

-       private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
+       protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
        {
            if (!is_ragged_input)
            {
@@ -528,44 +528,6 @@ namespace Tensorflow.Keras.Layers
            throw new NotImplementedException();
        }

-       // It seems the cell cannot be passed as an interface type.
-       //public RNN New(IRnnArgCell cell,
-       //    bool return_sequences = false,
-       //    bool return_state = false,
-       //    bool go_backwards = false,
-       //    bool stateful = false,
-       //    bool unroll = false,
-       //    bool time_major = false)
-       //    => new RNN(new RNNArgs
-       //    {
-       //        Cell = cell,
-       //        ReturnSequences = return_sequences,
-       //        ReturnState = return_state,
-       //        GoBackwards = go_backwards,
-       //        Stateful = stateful,
-       //        Unroll = unroll,
-       //        TimeMajor = time_major
-       //    });
-
-       //public RNN New(List<IRnnArgCell> cell,
-       //    bool return_sequences = false,
-       //    bool return_state = false,
-       //    bool go_backwards = false,
-       //    bool stateful = false,
-       //    bool unroll = false,
-       //    bool time_major = false)
-       //    => new RNN(new RNNArgs
-       //    {
-       //        Cell = cell,
-       //        ReturnSequences = return_sequences,
-       //        ReturnState = return_state,
-       //        GoBackwards = go_backwards,
-       //        Stateful = stateful,
-       //        Unroll = unroll,
-       //        TimeMajor = time_major
-       //    });

        protected Tensors get_initial_state(Tensors inputs)
        {
            var input = inputs[0];


test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+9, -0)

@@ -146,6 +146,15 @@ namespace Tensorflow.Keras.UnitTest.Layers

        }

        [TestMethod]
        public void GRU()
        {
            // 32 samples, 10 timesteps, 8 features in; 4 GRU units out.
            var inputs = tf.ones((32, 10, 8));
            var gru = tf.keras.layers.GRU(4);
            var output = gru.Apply(inputs);
            Assert.AreEqual((32, 4), output.shape);
        }

        [TestMethod]
        public void Bidirectional()
        {


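A companion test one might add for the sequence flag (a sketch, not part of the diff):

        [TestMethod]
        public void GRUReturnSequences()
        {
            // With return_sequences the layer emits one 4-unit output per timestep.
            var inputs = tf.ones((32, 10, 8));
            var gru = tf.keras.layers.GRU(4, return_sequences: true);
            var output = gru.Apply(inputs);
            Assert.AreEqual((32, 10, 4), output.shape);
        }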