Browse Source

Add Attention support and test it

pull/943/head
hlx1120@outlook.com 3 years ago
parent
commit
875e41d38b
7 changed files with 800 additions and 0 deletions
  1. +20
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs
  2. +20
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs
  3. +9
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs
  4. +159
    -0
      src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
  5. +257
    -0
      src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
  6. +25
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
  7. +310
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

+ 20
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs View File

@@ -0,0 +1,20 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class AttentionArgs : BaseDenseAttentionArgs
{

/// <summary>
/// If `true`, will create a scalar variable to scale the attention scores.
/// </summary>
public bool use_scale { get; set; } = false;

/// <summary>
/// Function to use to compute attention scores, one of
/// `{"dot", "concat"}`. `"dot"` refers to the dot product between the query
/// and key vectors. `"concat"` refers to the hyperbolic tangent of the
/// concatenation of the query and key vectors.
/// </summary>
public string score_mode { get; set; } = "dot";

}
}

+ 20
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs View File

@@ -0,0 +1,20 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class BaseDenseAttentionArgs : LayerArgs
{

/// <summary>
/// Boolean. Set to `true` for decoder self-attention. Adds a mask such
/// that position `i` cannot attend to positions `j > i`. This prevents the
/// flow of information from the future towards the past.
/// </summary>
public bool causal { get; set; } = false;

/// <summary>
/// Float between 0 and 1. Fraction of the units to drop for the
/// attention scores.
/// </summary>
public float dropout { get; set; } = 0f;

}
}

+ 9
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -272,6 +272,15 @@ namespace Tensorflow.Keras.Engine
weights.AddRange(non_trainable_weights);
return weights;
}
set
{
if (weights.Count() != value.Count()) throw new ValueError(
$"You called `set_weights` on layer \"{this.name}\"" +
$"with a weight list of length {len(value)}, but the layer was " +
$"expecting {len(weights)} weights.");
foreach (var (this_w, v_w) in zip(weights, value))
this_w.assign(v_w, read_value: true);
}
}

public virtual LayerArgs get_config()


+ 159
- 0
src/TensorFlowNET.Keras/Layers/Attention/Attention.cs View File

@@ -0,0 +1,159 @@
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Dot-product attention layer, a.k.a. Luong-style attention.
/// Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
/// shape `[batch_size, Tv, dim]` and `key` tensor of shape
/// `[batch_size, Tv, dim]`. The calculation follows the steps:
/// <para>
/// 1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
/// product: `scores = tf.matmul(query, key, transpose_b=True)`.
/// </para>
/// <para>
/// 2. Use scores to calculate a distribution with shape
/// `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
/// </para>
/// <para>
/// 3. Use `distribution` to create a linear combination of `value` with
/// shape `[batch_size, Tq, dim]`:
/// `return tf.matmul(distribution, value)`.
/// </para>
/// </summary>
/// <example> 0
/// <code>
/// //Variable-length int sequences.
/// var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
/// var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
/// // Embedding lookup.
/// var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
/// // Query embeddings of shape [batch_size, Tq, dimension].
/// var query_embeddings = token_embedding.Apply(query_input);
/// // Value embeddings of shape [batch_size, Tv, dimension].
/// var value_embeddings = token_embedding.Apply(value_input);
/// // CNN layer.
/// var cnn_layer = keras.layers.Conv1D(
/// filters: 100,
/// kernel_size: 4,
/// // Use 'same' padding so outputs have the same shape as inputs.
/// padding: "same");
/// var cnn_layer2 = keras.layers.Conv1D(
/// filters: 100,
/// kernel_size: 4,
/// // Use 'same' padding so outputs have the same shape as inputs.
/// padding: "same");
/// // Query encoding of shape [batch_size, Tq, filters].
/// var query_seq_encoding = cnn_layer.Apply(query_embeddings);
/// // Value encoding of shape [batch_size, Tv, filters].
/// var value_seq_encoding = cnn_layer.Apply(value_embeddings);
/// // Query-value attention of shape [batch_size, Tq, filters].
/// var query_value_attention_seq = keras.layers.Attention().Apply(
/// (query_seq_encoding, value_seq_encoding));
/// // Reduce over the sequence axis to produce encodings of shape
/// // [batch_size, filters].
/// var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
/// query_seq_encoding);
/// var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
/// query_value_attention_seq);
/// // Concatenate query and document encodings to produce a DNN input layer.
/// var input_layer = keras.layers.Concatenate().Apply(
/// (query_encoding, query_value_attention));
/// // Add DNN layers, and create Model.
/// // ...
/// </code>
/// </example>
public class Attention : BaseDenseAttention
{
public IVariableV1 concat_score_weight;
public IVariableV1 scale;

AttentionArgs args;
string score_mode { get => args.score_mode; }
bool use_scale { get => args.use_scale; }
public Attention(AttentionArgs args) : base(args)
{
this.args = args;
if (!new List<string> {
"dot",
"concat"
}.Contains(this.score_mode))
throw new ValueError("Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]");
}
// Creates variable when `use_scale` is True or `score_mode` is `concat`.
protected override void build(Tensors inputs) {
if (this.use_scale)
this.scale = this.add_weight(name: "scale",
shape: 1,
initializer: tf.ones_initializer,
dtype: this.DType,
trainable: true);
else
this.scale = null;

if (this.score_mode == "concat")
this.concat_score_weight = this.add_weight(name: "concat_score_weight",
shape: 1,
initializer: tf.ones_initializer,
dtype: this.DType,
trainable: true);
else
this.concat_score_weight = null;
base.build(inputs);
}

/// <summary>
/// Calculates attention scores as a query-key dot product.
/// </summary>
/// <param name="query">query: Query tensor of shape `[batch_size, Tq, dim]`.</param>
/// <param name="key">key: Key tensor of shape `[batch_size, Tv, dim]`.</param>
/// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns>
public override Tensor _calculate_scores(Tensor query, Tensor key)
{
Tensor scores = null;
if (this.score_mode == "dot")
{
//scores = tf.matmul(query, key, transpose_b: true);
//scores = tf.matmul(tf.squeeze(query),tf.squeeze(key), transpose_b: true);
scores = tf.linalg.einsum("bij,bkj->bik", (query, key));
if (this.scale != null)
scores *= this.scale.AsTensor();
} else if (this.score_mode == "concat") {
// Reshape tensors to enable broadcasting.
// Reshape into [batch_size, Tq, 1, dim].
var q_reshaped = tf.expand_dims(query, axis: -2);
// Reshape into [batch_size, 1, Tv, dim].
var k_reshaped = tf.expand_dims(key, axis: -3);
if (this.scale != null)
scores = this.concat_score_weight.AsTensor() *
tf.reduce_sum(tf.tanh(this.scale.AsTensor() * (q_reshaped + k_reshaped)), axis: -1);
else
scores = this.concat_score_weight.AsTensor() *
tf.reduce_sum(tf.tanh(q_reshaped + k_reshaped), axis: -1);
}
return scores;
}

public override LayerArgs get_config() => this.args;
//var config = new Dictionary<object, object> {
// {
// "use_scale",
// this.use_scale},
// {
// "score_mode",
// this.score_mode}};
//var base_config = base.get_config();
//return new dict(base_config.items().ToList() + config.items().ToList());
}
}

+ 257
- 0
src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs View File

@@ -0,0 +1,257 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Collections.Generic;
using System.Linq;

/// <summary>
/// Base class for attention layers that can be used in sequence DNN/CNN models.
///This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
///Attention is formed by three tensors: Query, Key and Value.
/// </summary>

namespace Tensorflow.Keras.Layers
{

/// <summary>
/// Base Attention class for Dense networks.
/// This class is suitable for Dense or CNN networks, and not for RNN networks.
/// Implementations of attention mechanisms should inherit from this class, and
/// reuse the `apply_attention_scores()` method.
/// </summary>
public class BaseDenseAttention : Layer
{

BaseDenseAttentionArgs args;

bool causal { get => args.causal; }
float dropout { get => args.dropout; }

protected bool supports_masking;
public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args)
{
this.args = args;
this.supports_masking = true;
}

/// <summary>
/// Calculates attention scores.
/// </summary>
/// <param name="query">query: Query tensor of shape `[batch_size, Tq, dim]`.</param>
/// <param name="key">key: Key tensor of shape `[batch_size, Tv, dim]`.</param>
/// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns>
public virtual Tensor _calculate_scores(Tensor query, Tensor key) =>
throw new NotImplementedException("");

/// <summary>
/// Applies attention scores to the given value tensor.
/// To use this method in your attention layer, follow the steps:
/// <para>
/// * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape
/// `[batch_size, Tv]` to calculate the attention `scores`.
/// </para>
/// <para>
/// * Pass `scores` and `value` tensors to this method. The method applies
/// `scores_mask`, calculates `attention_distribution = softmax(scores)`, then
/// returns `matmul(attention_distribution, value).
/// </para>
/// <para>
/// * Apply `query_mask` and return the result.
/// </para>
/// </summary>
/// <param name="scores">Scores float tensor of shape `[batch_size, Tq, Tv]`.</param>
/// <param name="value">Value tensor of shape `[batch_size, Tv, dim]`.</param>
/// <param name="scores_mask">
/// A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
/// [batch_size, Tq, Tv]`. If given, scores at positions where
/// `scores_mask==False` do not contribute to the result. It must contain
/// at least one `True` value in each line along the last dimension.
/// </param>
/// <param name="training">
/// Boolean indicating whether the layer should behave in
/// training mode (adding dropout) or in inference mode (no dropout).
/// </param>
/// <returns>
/// <para>
/// Tensor of shape `[batch_size, Tq, dim]`.
/// </para>
/// <para>
/// Attention scores after masking and softmax with shape
/// [batch_size, Tq, Tv]`.
/// </para>
/// </returns>
public (Tensor, Tensor) _apply_scores(Tensor scores,
Tensor value,
Tensor scores_mask = null,
bool? training = null)
{
if (scores_mask != null)
{
var padding_mask = tf.logical_not(scores_mask);
// Bias so padding positions do not contribute to attention distribution.
// Note 65504. is the max float16 value.
if (scores.dtype == tf.float16)
scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype);
else
scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype);
}
bool _training;
training ??= false; // TODO: Delete this line when backend.learning_phase is available
if (training == null)
_training = keras.backend.learning_phase() ==
Tensorflow.Keras.GraphLearningPhase.train_mode ?
true : false;
else _training = training.Value;
var weights = tf.nn.softmax(scores);
Func<Tensor> dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout);
weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights));
//return (tf.matmul(weights, value), weights);
return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
Tensors _inp;
Tensors _mask = null;

int count = inputs.Count();
if (count < 2 || count > 6) throw new ValueError(
$"{ this.name } layer accepts inputs list of length from 2 to 5, " +
$"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
$"Received length: {count}.");

bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
bool return_attention_scores = false;
if (has_bool)
{
return_attention_scores = (bool)inputs[count - 1];
count--;
}

switch (count)
{
case 2:
_inp = (inputs[0], inputs[1]);
break;
case 3:
_inp = new[] { inputs[0], inputs[1], inputs[2] };
break;
case 4:
if (inputs[0].shape == inputs[2].shape)
if (inputs[1].shape == inputs[3].shape)
{
_inp = new[] { inputs[0], inputs[1] };
_mask = new[] { inputs[2], inputs[3] };
break;
}
throw new ValueError(); //TODO:Add discriptions for this err
case 5:
_inp = new[] { inputs[0], inputs[1], inputs[2] };
_mask = (inputs[3], inputs[4]);
break;
default:
throw new ValueError(); //TODO:Add discriptions for this err
}

return call(_inp, _mask, training, return_attention_scores);
}

protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false)
{
Tensor causal_mask;
//this._validate_call_args(inputs: inputs, mask: mask);
var q = inputs[0];
var v = inputs[1];
var k = inputs.Count() > 2 ? inputs[2] : v;
var q_mask = mask != null ? mask[0] : null;
var v_mask = mask != null ? mask[1] : null;
var scores = this._calculate_scores(query: q, key: k);
if (v_mask != null)
// Mask of shape [batch_size, 1, Tv].
v_mask = tf.expand_dims(v_mask, axis: -2);
if (this.causal)
{
// Creates a lower triangular mask, so position i cannot attend to
// positions j>i. This prevents the flow of information from the future
// into the past.
var scores_shape = tf.shape(scores);
// causal_mask_shape = [1, Tq, Tv].
var causal_mask_shape = tf.concat(new List<Tensor> {
tf.ones_like(tf.slice(scores_shape, new[]{0}, new[]{-2})),
tf.concat(new[]{scores_shape[-2], scores_shape[-1]}, 0)
}, axis: 0);
var _causal_mask_shape = new Shape(causal_mask_shape.ToArray<int>());
causal_mask = _lower_triangular_mask(_causal_mask_shape);
}
else
causal_mask = null;
var scores_mask = _merge_masks(v_mask, causal_mask);
var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training);
if (q_mask != null)
{
// Mask of shape [batch_size, Tq, 1].
q_mask = tf.expand_dims(q_mask, axis: -1);
result *= tf.cast(q_mask, dtype: result.dtype);
}
if (return_attention_scores)
return new Tensors(result, attention_scores);
return result;
}
public Tensor compute_mask(Tensors inputs, Tensors mask = null)
{
this._validate_call_args(inputs: inputs, mask: mask);
if (mask != null)
{
var q_mask = mask[0];
if (q_mask == null)
return null;
return tf.convert_to_tensor(q_mask);
}
return null;
}

//public Shape compute_output_shape(Shape input_shape) {
// // return_attention_scores argument of BaseDenseAttention.call method
// // is ignored. Output shape of attention_scores cannot be returned.
// return input_shape[0];
//}

/// <summary>
/// Validates arguments of the call method.
/// </summary>
public void _validate_call_args(Tensors inputs, Tensors mask)
{
if (inputs.Count() < 2 || inputs.Count() > 3)
throw new ValueError(
$"{this.name} layer accepts inputs list of length 2 or 3, " +
$"namely [query, value] or [query, value, key]. Received length: {len(inputs)}.");
if (mask != null)
if (mask.Count() < 2 || mask.Count() > inputs.Count())
throw new ValueError($"{this.name} layer mask must be a list of length 2, " +
$"namely [query_mask, value_mask]. Received length: {len(mask)}.");
}

public static Tensor _lower_triangular_mask(Shape shape)
{
var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2);
var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1);
return tf.greater_equal(row_index, col_index);
}

public static Tensor _merge_masks(Tensor x, Tensor y)
{
if (x == null)
return y;
if (y == null)
return x;
return tf.logical_and(x, y);
}

public override LayerArgs get_config() => this.args;
}
}

+ 25
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs View File

@@ -0,0 +1,25 @@
using System;
using Tensorflow.NumPy;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{
public partial class LayersApi
{
public Attention Attention(bool use_scale = false,
string score_mode = "dot",
bool causal = false,
float dropout = 0f) =>
new Attention(new AttentionArgs
{
use_scale = use_scale,
score_mode = score_mode,
causal = causal,
dropout = dropout
});
}
}

+ 310
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs View File

@@ -0,0 +1,310 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Layers;
using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;

namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class AttentionTest : EagerModeTestBase
{
#region BaseDenseAttention
[TestMethod]
public void test_one_dim_with_mask()
{
// Scores tensor of shape [1, 1, 1]
var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
// Value tensor of shape [1, 1, 1]
var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
// Scores mask tensor of shape [1, 1, 1]
var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = [[[1]]]
var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6
var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_one_dim_no_mask()
{
// Scores tensor of shape [1, 1, 1]
var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
// Value tensor of shape [1, 1, 1]
var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = [[[1]]]
var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6
var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_multi_dim_with_mask()
{
// Scores tensor of shape [1, 1, 3]
var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32);
// Value tensor of shape [1, 3, 1]
var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32);
// Scores mask tensor of shape [1, 1, 3]
var scores_mask = np.array(new[, ,] { { { true, true, false } } }, dtype: np.@bool);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax scores = softmax(scores) with zeros in positions where
// v_mask == False.
// => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
// softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
// softmax_scores002 = 0
var expected_scores = np.array(new[, ,] { { { 0.73105857863f, 0.26894142137f, 0f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8
// = 1.35795272077
//Actually the output is 1.3579528
var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}
[TestMethod]
public void test_multi_dim_no_mask()
{
// Scores tensor of shape [1, 1, 3]
var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32);
// Value tensor of shape [1, 3, 1]
var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = softmax(scores).
// => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
// = 0.42231879825
// softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
// = 0.15536240349
// softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
// = 0.42231879825
//Actually the output is 0.42231882, 0.15536241, 0.42231882
var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [1, 1, 1].
// expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
// - 0.42231879825 * 0.8
// = 0.44660872104
//Actually the output is 0.44660875
var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_one_dim_batch_size_two()
{
// Scores tensor of shape [2, 1, 1]
var scores = np.array(new[, ,] { { { 1.1f } }, { { 2.1f } } }, dtype: np.float32);
// Value tensor of shape [2, 1, 1]
var v = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32);
// Scpres mask tensor of shape [2, 1, 1]
var scores_mask = np.array(new[, ,] { { { true } }, { { true } } }, dtype: np.@bool);
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected softmax_scores = [[[1]], [[1]]]
var expected_scores = np.array(new[, ,] { { { 1f } }, { { 1f } } }, dtype: np.float32);
Assert.AreEqual(expected_scores, actual_scores.numpy());
// Expected tensor of shape [2, 1, 1].
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6
// expected100 = softmax_scores[1, 0] * 2.6 = 2.6
var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_shape_with_dropout()
{
// scores: Scores float tensor of shape `[batch_size, tq, tv]`.
// value: Value tensor of shape `[batch_size, tv, dim]`.
var batch_size = 4;
var tq = 5;
var tv = 6;
var dim = 7;
var scores = np.ones((batch_size, tq, tv));
var value = np.ones((batch_size, tv, dim));
var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f })
._apply_scores(scores: scores, value: value, training: false);
var actual = _tup_1.Item1;
var actual_scores = _tup_1.Item2;
// Expected Tensor of shape `[batch_size, tq, tv]`.
var expected_scores_shape = new[] {
batch_size,
tq,
tv
};
Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy());
// Expected Tensor of shape `[batch_size, tq, dim]`.
var expected_shape = new[] {
batch_size,
tq,
dim
};
Assert.AreEqual(expected_shape, tf.shape(actual).numpy());
}
#endregion
// ------------------------------------------------------------------
#region Attention
[TestMethod]
public void test_example()
{
//Variable-length int sequences.
var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
// Embedding lookup.
var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
// Query embeddings of shape [batch_size, Tq, dimension].
var query_embeddings = token_embedding.Apply(query_input);
// Value embeddings of shape [batch_size, Tv, dimension].
var value_embeddings = token_embedding.Apply(value_input);
// CNN layer.
var cnn_layer = keras.layers.Conv1D(
filters: 100,
kernel_size: 4,
// Use 'same' padding so outputs have the same shape as inputs.
padding: "same",
activation: "relu");
var cnn_layer2 = keras.layers.Conv1D(
filters: 100,
kernel_size: 4,
// Use 'same' padding so outputs have the same shape as inputs.
padding: "same",
activation: "relu");
// Query encoding of shape [batch_size, Tq, filters].
var query_seq_encoding = cnn_layer.Apply(query_embeddings);
// Value encoding of shape [batch_size, Tv, filters].
var value_seq_encoding = cnn_layer2.Apply(value_embeddings);
// Query-value attention of shape [batch_size, Tq, filters].
var query_value_attention_seq = keras.layers.Attention().Apply(
(query_seq_encoding, value_seq_encoding));
// Reduce over the sequence axis to produce encodings of shape
// [batch_size, filters].
var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
query_seq_encoding);
var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
query_value_attention_seq);
// Concatenate query and document encodings to produce a DNN input layer.
var input_layer = keras.layers.Concatenate().Apply(
(query_encoding, query_value_attention));
// Add DNN layers, and create Model.
// ...
}

[TestMethod]
public void test_calculate_scores_one_dim()
{
// Query tensor of shape [1, 1, 1]
var q = np.array(new[,,] { { { 1.1f } } }, dtype: np.float32);
// Key tensor of shape [1, 1, 1]
var k = np.array(new[,,] { { { 1.6f } } }, dtype: np.float32);
var attention_layer = keras.layers.Attention();
//attention_layer.build((1));
var actual = attention_layer._calculate_scores(query: q, key: k);
// Expected tensor of shape [1, 1, 1].
// expected000 = 1.1*1.6 = 1.76
// Actually the output is 1.7600001
var expected = np.array(new[,,] { { { 1.7600001f } } }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_calculate_scores_multi_dim()
{
// Query tensor of shape [1, 2, 4]
var q = np.array(new[, ,] { {
{ 1f, 1.1f, 1.2f, 1.3f },
{ 2f, 2.1f, 2.2f, 2.3f }
} }, dtype: np.float32);
// Key tensor of shape [1, 3, 4]
var k = np.array(new[, ,] { {
{ 1.5f, 1.6f, 1.7f, 1.8f },
{ 2.5f, 2.6f, 2.7f, 2.8f },
{ 3.5f, 3.6f, 3.7f, 3.8f }
} }, dtype: np.float32);
var attention_layer = keras.layers.Attention();
//attention_layer.build(((1, 2, 4), (1, 3, 4)));
var actual = attention_layer._calculate_scores(query: q, key: k);
// Expected tensor of shape [1, 2, 3].
// expected000 = 1.*1.5+1.1*1.6+1.2*1.7+1.3*1.8 = 7.64
// expected001 = 1.*2.5+1.1*2.6+1.2*2.7+1.3*2.8 = 12.24
// expected002 = 1.*3.5+1.1*3.6+1.2*3.7+1.3*3.8 = 16.84
// expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24
// expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84
// expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44
// Actually the output000 is 7.6400003, the output012 is 31.439999
var expected = np.array(new[, ,] { {
{ 7.6400003f, 12.24f, 16.84f },
{ 14.24f, 22.84f, 31.439999f }
} }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}

[TestMethod]
public void test_calculate_scores_multi_dim_concat()
{
// Query tensor of shape [1, 2, 4]
var q = np.array(new[, ,] { {
{ 1f, 1.1f, 1.2f, 1.3f },
{ 2f, 2.1f, 2.2f, 2.3f }
} }, dtype: np.float32);
// Key tensor of shape [1, 3, 4]
var k = np.array(new[, ,] { {
{ 1.5f, 1.6f, 1.7f, 1.8f },
{ 2.5f, 2.6f, 2.7f, 2.8f },
{ 3.5f, 3.6f, 3.7f, 3.8f }
} }, dtype: np.float32);
var attention_layer = keras.layers.Attention(score_mode: "concat");
//attention_layer.concat_score_weight = 1;
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() {
Name = "concat_score_weight",
Shape = (1),
DType = TF_DataType.TF_FLOAT,
Getter = base_layer_utils.make_variable,
Overwrite = true,
Initializer = tf.ones_initializer,
Synchronization = VariableSynchronization.Auto,
Aggregation = VariableAggregation.None,
Trainable = true
});
//attention_layer.build(((1, 2, 4), (1, 3, 4)));
//var actual = keras.backend.get_value(attention_layer._calculate_scores(query: q, key: k));
var actual = attention_layer._calculate_scores(query: q, key: k);
// pylint:disable=line-too-long
// expected000 = tanh(1.+1.5) + tanh(1.1+1.6) + tanh(1.2+1.7) + tanh(1.3+1.8) = 3.96753427840
// expected001 = tanh(1.+2.5) + tanh(1.1+2.6) + tanh(1.2+2.7) + tanh(1.3+2.8) = 3.99558784825
// expected002 = tanh(1.+3.5) + tanh(1.1+3.6) + tanh(1.2+3.7) + tanh(1.3+3.8) = 3.99940254147
// expected010 = tanh(2.+1.5) + tanh(2.1+1.6) + tanh(2.2+1.7) + tanh(2.3+1.8) = 3.99558784825
// expected011 = tanh(2.+2.5) + tanh(2.1+2.6) + tanh(2.2+2.7) + tanh(2.3+2.8) = 3.99940254147
// expected012 = tanh(2.+3.5) + tanh(2.1+3.6) + tanh(2.2+3.7) + tanh(2.3+3.8) = 3.99991913657
//Actually the output012 is 3.9999194
var expected = np.array(new[, ,] { {
{ 3.96753427840f, 3.99558784825f, 3.99940254147f },
{ 3.99558784825f, 3.99940254147f, 3.9999194f }
} }, dtype: np.float32);
Assert.AreEqual(expected, actual.numpy());
}
#endregion
}

}

Loading…
Cancel
Save