Browse Source

multi-head attention

tags/v0.70.2-NET6
hlx1120@outlook.com Haiping 3 years ago
parent
commit
2e94ed38b0
5 changed files with 409 additions and 161 deletions
  1. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs
  2. +1
    -1
      src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
  3. +352
    -0
      src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
  4. +31
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
  5. +23
    -159
      test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs View File

@@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers {
axis = args.axis;
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor x = inputs;
Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9)
: inputs;
Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));
Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true);
return tf.div(e, s);


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs View File

@@ -120,7 +120,7 @@ namespace Tensorflow.Keras.Layers

int count = inputs.Count();
if (count < 2 || count > 6) throw new ValueError(
$"{ this.name } layer accepts inputs list of length from 2 to 5, " +
$"{ this.name } layer accepts inputs list of length from 2 to 6, " +
$"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
$"Received length: {count}.");



+ 352
- 0
src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs View File

@@ -0,0 +1,352 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;

namespace Tensorflow.Keras.Layers
{
public class MultiHeadAttention : Layer
{
static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";

MultiHeadAttentionArgs args;
Shape _query_shape = null;
Shape _key_shape = null;
Shape _value_shape = null;
bool _built_from_signature = false;
EinsumDense _query_dense = null;
EinsumDense _key_dense = null;
EinsumDense _value_dense = null;
EinsumDense _output_dense = null;
string _dot_product_equation = "";
string _combine_equation = "";
Softmax _softmax = null;
Dropout _dropout_layer = null;

/// <summary>
/// Builds einsum equations for the attention computation.
/// Query, key, value inputs after projection are expected to have the shape as:
/// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
/// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
///
/// <para>
/// The attention operations can be generalized:
/// </para>
/// <para>
/// (1) Query-key dot product:
/// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
/// [key attention dims], num_heads, channels) -> ([batch dim],
/// num_heads, [query attention dims], [key attention dims])`
/// </para><para>
/// (2) Combination:
/// `([batch dims], num_heads, [query attention dims], [key attention dims]),
/// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
/// [query attention dims], num_heads, channels)`
/// </para>
/// </summary>
/// <param name="rank">Rank of query, key, value tensors.</param>
/// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
/// that attention will be applied to.</param>
/// <returns></returns>
public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
{
var target_notation = _CHR_IDX.Substring(0, rank);
// `batch_dims` includes the head dim.
var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
var letter_offset = rank;
var source_notation = "";
for (int i = 0; i < rank; i++)
{
if (batch_dims.Contains(i) || i == rank - 1)
source_notation += target_notation[i];
else
{
source_notation += _CHR_IDX[letter_offset];
letter_offset += 1;
}
}
var product_notation = "".Insert(0, new string((from i in batch_dims
select (char)(int)target_notation[i]).Concat(
from i in attn_axes.as_int_list()
select (char)(int)target_notation[i]).Concat(
from i in attn_axes.as_int_list()
select source_notation[i]).ToArray()));
var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
var attn_scores_rank = product_notation.Count();
var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
return (dot_product_equation, combine_equation, attn_scores_rank);
}

/// <summary>
/// Builds an einsum equation for projections inside multi-head attention.
/// </summary>
public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
{
char _char;
var input_str = "";
var kernel_str = "";
var output_str = "";
var bias_axes = "";
var letter_offset = 0;
foreach (var i in range(free_dims))
{
_char = _CHR_IDX[i + letter_offset];
input_str += _char;
output_str += _char;
}
letter_offset += free_dims;
foreach (var i in range(bound_dims))
{
_char = _CHR_IDX[i + letter_offset];
input_str += _char;
kernel_str += _char;
}
letter_offset += bound_dims;
foreach (var i in range(output_dims))
{
_char = _CHR_IDX[i + letter_offset];
kernel_str += _char;
output_str += _char;
bias_axes += _char;
}
var equation = $"{input_str},{kernel_str}->{output_str}";
return (equation, bias_axes, output_str.Count());
}

static Shape _get_output_shape(int output_rank, Shape known_last_dims)
=> (from _ in range(output_rank - known_last_dims.rank)
select -1).Concat(known_last_dims.as_int_list()).ToArray();

public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
{
this.args = args;
}

public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
=> this._build_from_signature(query.shape, value.shape, key?.shape);

public void _build_from_signature(Shape query, Shape value, Shape key = null)
{
this._built_from_signature = true;
this._query_shape = query;
this._value_shape = value;
if (key == null)
this._key_shape = this._value_shape;
else
this._key_shape = key;
// Any setup work performed only once should happen in an `init_scope`
// to avoid creating symbolic Tensors that will later pollute any eager
// operations.
tf_with(tf.init_scope(), _ =>
{
var free_dims = this._query_shape.rank - 1;
var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
free_dims, bound_dims: 1, output_dims: 2);
this._query_dense = _get_dense(einsum_equation,
_get_output_shape(output_rank - 1,
(this.args.NumHeads, this.args.KeyDim)),
this.args.UseBias ? bias_axes : null,
"query");
(einsum_equation, bias_axes, output_rank) = _build_proj_equation(
this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
this._key_dense = _get_dense(einsum_equation,
_get_output_shape(output_rank - 1,
(this.args.NumHeads, this.args.KeyDim)),
this.args.UseBias ? bias_axes : null,
"key");
(einsum_equation, bias_axes, output_rank) = _build_proj_equation(
this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
this._value_dense = _get_dense(einsum_equation,
_get_output_shape(output_rank - 1,
(this.args.NumHeads, this.args.ValueDim ?? -1)),
this.args.UseBias ? bias_axes : null,
"value");
// Builds the attention computations for multi-head dot product attention.
// These computations could be wrapped into the keras attention layer once
// it support mult-head einsum computations.
this._build_attention(output_rank);
this._output_dense = _build_output_dense(free_dims, "attention_output");
});
this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
}

EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
=> new EinsumDense(new EinsumDenseArgs()
{
Equation = equation,
OutputShape = output_shape,
BiasAxes = bias_axes,
Name = name,
KernelInitializer = this.args.KernelInitializer,
BiasInitializer = this.args.BiasInitializer,
KernelRegularizer = this.args.KernelRegularizer,
BiasRegularizer = this.args.BiasRegularizer,
KernelConstraint = this.args.KernelConstraint,
BiasConstraint = this.args.BiasConstraint
});

EinsumDense _build_output_dense(int free_dims, string name)
{
if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]);
var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
return _get_dense(einsum_equation,
_get_output_shape(output_rank - 1, this.args.OutputShape),
this.args.UseBias ? bias_axes : null,
name);
}

void _build_attention(int rank)
{
if (this.args.AttentionAxis == null)
this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
int attn_scores_rank;
(this._dot_product_equation, this._combine_equation, attn_scores_rank)
= _build_attention_equation(rank, this.args.AttentionAxis);
var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
attn_scores_rank).ToArray();
this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
}

Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
{
if(attention_mask != null)
{
var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++)
attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
}
return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
}

public Tensors _compute_attention(
Tensor query,
Tensor key,
Tensor value,
Tensor attention_mask = null,
bool training = false)
{
// Note: Applying scalar multiply at the smaller end of einsum improves
// XLA performance, but may introduce slight numeric differences in
// the Transformer attention head.
query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
// Take the dot product between "query" and "key" to get the raw
// attention scores.
var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
attention_scores = this._masked_softmax(attention_scores, attention_mask);
// This is actually dropping out entire tokens to attend to, which might
// seem a bit unusual, but is taken from the original Transformer paper.
var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
// `context_layer` = [B, T, N, H]
var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
return (attention_output, attention_scores);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
Tensors _inp;
Tensor _mask = null;

int count = inputs.Count();
if (count < 2 || count > 5) throw new ValueError(
$"{ this.name } layer accepts inputs list of length from 2 to 5, " +
$"namely [query, value, (key), (attention_mask), (return_attention_scores)]." +
$"Received length: {count}.");

bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
bool return_attention_scores = false;
if (has_bool)
{
return_attention_scores = (bool)inputs[count - 1];
count--;
}

switch (count)
{
case 2:
_inp = (inputs[0], inputs[1]);
break;
case 3:
if (inputs[2].shape[-1] != inputs[0].shape[-1])
_inp = new[] { inputs[0], inputs[1], inputs[2] };
else
{
_inp = (inputs[0], inputs[1]);
_mask = inputs[2];
}
break;
case 4:
_inp = new[] { inputs[0], inputs[1], inputs[2] };
_mask = inputs[3];
break;
default:
throw new ValueError(); //TODO:Add discriptions for this err
}

return call(_inp, _mask, training, return_attention_scores);
}

protected Tensors call(Tensors inputs,
Tensor attention_mask,
bool? training = null,
bool return_attention_scores = false)
{
var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
if (!this._built_from_signature)
this._build_from_signature(query: query, value: value, key: key);
if (key == null)
key = value;

// TODO: Add RaggedTensor support
//var query_is_ragged = query is tf.RaggedTensor;
//if (query_is_ragged)
//{
// var query_lengths = query.nested_row_lengths();
// query = query.to_tensor();
//}
//var key_is_ragged = key is tf.RaggedTensor;
//var value_is_ragged = value is tf.RaggedTensor;
//if (key_is_ragged && value_is_ragged)
//{
// // Ensure they have the same shape.
// var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
// key = key.to_tensor(shape: bounding_shape);
// value = value.to_tensor(shape: bounding_shape);
//}
//else if (key_is_ragged)
//{
// key = key.to_tensor(shape: tf.shape(value));
//}
//else if (value_is_ragged)
//{
// value = value.to_tensor(shape: tf.shape(key));
//}

// N = `num_attention_heads`
// H = `size_per_head`
// `query` = [B, T, N ,H]
query = this._query_dense.Apply(query);
// `key` = [B, S, N, H]
key = this._key_dense.Apply(key);
// `value` = [B, S, N, H]
value = this._value_dense.Apply(value);
var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false);
attention_output = this._output_dense.Apply(attention_output);

//if (query_is_ragged)
//{
// attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
//}

if (return_attention_scores)
return (attention_output, attention_scores);
return attention_output;
}
}
}

+ 31
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs View File

@@ -21,5 +21,36 @@ namespace Tensorflow.Keras.Layers
causal = causal,
dropout = dropout
});
public MultiHeadAttention MultiHeadAttention(int num_heads,
int key_dim,
int? value_dim = null,
float dropout = 0f,
bool use_bias = true,
Shape output_shape = null,
Shape attention_axes = null,
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null,
IRegularizer kernel_regularizer = null,
IRegularizer bias_regularizer = null,
IRegularizer activity_regularizer = null,
Action kernel_constraint = null,
Action bias_constraint = null) =>
new MultiHeadAttention(new MultiHeadAttentionArgs
{
NumHeads = num_heads,
KeyDim = key_dim,
ValueDim = value_dim,
Dropout = dropout,
UseBias = use_bias,
OutputShape = output_shape,
AttentionAxis = attention_axes,
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
BiasInitializer = bias_initializer ?? tf.zeros_initializer,
KernelRegularizer = kernel_regularizer,
BiasRegularizer = bias_regularizer,
ActivityRegularizer = activity_regularizer,
KernelConstraint = kernel_constraint,
BiasConstraint = bias_constraint,
});
}
}

+ 23
- 159
test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs View File

@@ -15,45 +15,6 @@ namespace TensorFlowNET.Keras.UnitTest
public class AttentionTest : EagerModeTestBase
{
#region BaseDenseAttention
[TestMethod]
public void test_one_dim_with_mask()
{
// Scores tensor of shape [1, 1, 1]
var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
// Value tensor of shape [1, 1, 1]
var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
// Scores mask tensor of shape [1, 1, 1]
var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = [[[1]]]
var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6
var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_one_dim_no_mask()
{
// Scores tensor of shape [1, 1, 1]
var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
// Value tensor of shape [1, 1, 1]
var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = [[[1]]]
var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6
var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_multi_dim_with_mask()
@@ -81,35 +42,6 @@ namespace TensorFlowNET.Keras.UnitTest
var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}
[TestMethod]
public void test_multi_dim_no_mask()
{
// Scores tensor of shape [1, 1, 3]
var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32);
// Value tensor of shape [1, 3, 1]
var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = softmax(scores).
// => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
// = 0.42231879825
// softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
// = 0.15536240349
// softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
// = 0.42231879825
//Actually the output is 0.42231882, 0.15536241, 0.42231882
var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
// - 0.42231879825 * 0.8
// = 0.44660872104
//Actually the output is 0.44660875
var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_one_dim_batch_size_two()
@@ -132,101 +64,10 @@ namespace TensorFlowNET.Keras.UnitTest
var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_shape_with_dropout()
{
// scores: Scores float tensor of shape `[batch_size, tq, tv]`.
// value: Value tensor of shape `[batch_size, tv, dim]`.
var batch_size = 4;
var tq = 5;
var tv = 6;
var dim = 7;
var scores = np.ones((batch_size, tq, tv));
var value = np.ones((batch_size, tv, dim));
var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f })
._apply_scores(scores: scores, value: value, training: false);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected Tensor of shape `[batch_size, tq, tv]`.
var expected_scores_shape = new[] {
batch_size,
tq,
tv
};
Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy());
// Expected Tensor of shape `[batch_size, tq, dim]`.
var expected_shape = new[] {
batch_size,
tq,
dim
};
Assert.AreEqual(expected_shape, tf.shape(actual).numpy());
}
#endregion
// ------------------------------------------------------------------
#region Attention
[TestMethod]
public void test_example()
{
//Variable-length int sequences.
var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
// Embedding lookup.
var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
// Query embeddings of shape [batch_size, Tq, dimension].
var query_embeddings = token_embedding.Apply(query_input);
// Value embeddings of shape [batch_size, Tv, dimension].
var value_embeddings = token_embedding.Apply(value_input);
// CNN layer.
var cnn_layer = keras.layers.Conv1D(
filters: 100,
kernel_size: 4,
// Use 'same' padding so outputs have the same shape as inputs.
padding: "same",
activation: "relu");
var cnn_layer2 = keras.layers.Conv1D(
filters: 100,
kernel_size: 4,
// Use 'same' padding so outputs have the same shape as inputs.
padding: "same",
activation: "relu");
// Query encoding of shape [batch_size, Tq, filters].
var query_seq_encoding = cnn_layer.Apply(query_embeddings);
// Value encoding of shape [batch_size, Tv, filters].
var value_seq_encoding = cnn_layer2.Apply(value_embeddings);
// Query-value attention of shape [batch_size, Tq, filters].
var query_value_attention_seq = keras.layers.Attention().Apply(
(query_seq_encoding, value_seq_encoding));
// Reduce over the sequence axis to produce encodings of shape
// [batch_size, filters].
var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
query_seq_encoding);
var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
query_value_attention_seq);
// Concatenate query and document encodings to produce a DNN input layer.
var input_layer = keras.layers.Concatenate().Apply(
(query_encoding, query_value_attention));
// Add DNN layers, and create Model.
// ...
}

[TestMethod]
public void test_calculate_scores_one_dim()
{
// Query tensor of shape [1, 1, 1]
var q = np.array(new[,,] { { { 1.1f } } }, dtype: np.float32);
// Key tensor of shape [1, 1, 1]
var k = np.array(new[,,] { { { 1.6f } } }, dtype: np.float32);
var attention_layer = keras.layers.Attention();
//attention_layer.build((1));
var actual = attention_layer._calculate_scores(query: q, key: k);
// Expected tensor of shape [1, 1, 1].
// expected000 = 1.1*1.6 = 1.76
// Actually the output is 1.7600001
var expected = np.array(new[,,] { { { 1.7600001f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_calculate_scores_multi_dim()
@@ -305,6 +146,29 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(expected, actual.numpy());
}
#endregion
// ------------------------------------------------------------------
#region MultiHeadAttention
[TestMethod]
public void test_masked_attention()
{
var query = keras.Input(shape: (4, 8));
var value = keras.Input(shape: (2, 8));
var mask_tensor = keras.Input(shape:(4, 2));
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
attention_layer.Apply(new[] { query, value, mask_tensor });

var from_data = 10 * np.random.randn(3, 4, 8);
var to_data = 10 * np.random.randn(3, 2, 8);

var mask_data = np.random.randint(2, size: (3, 4, 2));
var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });

var null_mask_data = np.ones((3, 4, 2));
var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });

Assert.AreNotEqual(masked_output_data, unmasked_output_data);
}
#endregion
}

}

Loading…
Cancel
Save