diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs new file mode 100644 index 00000000..73477c58 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs @@ -0,0 +1,20 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AttentionArgs : BaseDenseAttentionArgs + { + + /// + /// If `true`, will create a scalar variable to scale the attention scores. + /// + public bool use_scale { get; set; } = false; + + /// + /// Function to use to compute attention scores, one of + /// `{"dot", "concat"}`. `"dot"` refers to the dot product between the query + /// and key vectors. `"concat"` refers to the hyperbolic tangent of the + /// concatenation of the query and key vectors. + /// + public string score_mode { get; set; } = "dot"; + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs new file mode 100644 index 00000000..b2a0c3a5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs @@ -0,0 +1,20 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class BaseDenseAttentionArgs : LayerArgs + { + + /// + /// Boolean. Set to `true` for decoder self-attention. Adds a mask such + /// that position `i` cannot attend to positions `j > i`. This prevents the + /// flow of information from the future towards the past. + /// + public bool causal { get; set; } = false; + + /// + /// Float between 0 and 1. Fraction of the units to drop for the + /// attention scores. 
+ /// + public float dropout { get; set; } = 0f; + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 8ee3484f..03308ede 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -275,6 +275,15 @@ namespace Tensorflow.Keras.Engine weights.AddRange(non_trainable_weights); return weights; } + set // Assigns the given values to the layer's weights pairwise, in order; throws ValueError when the counts differ. + { + if (weights.Count() != value.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\" " + $"with a weight list of length {len(value)}, but the layer was " + $"expecting {len(weights)} weights."); + foreach (var (this_w, v_w) in zip(weights, value)) + this_w.assign(v_w, read_value: true); + } } public virtual LayerArgs get_config() diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs new file mode 100644 index 00000000..51a40b58 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -0,0 +1,159 @@ +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Dot-product attention layer, a.k.a. Luong-style attention. + /// Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of + /// shape `[batch_size, Tv, dim]` and `key` tensor of shape + /// `[batch_size, Tv, dim]`. The calculation follows the steps: + /// + /// 1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot + /// product: `scores = tf.matmul(query, key, transpose_b=True)`. + /// + /// + /// 2. Use scores to calculate a distribution with shape + /// `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`. + /// + /// + /// 3. 
Use `distribution` to create a linear combination of `value` with + /// shape `[batch_size, Tq, dim]`: + /// `return tf.matmul(distribution, value)`. + /// + /// + /// 0 + /// + /// // Variable-length int sequences. + /// var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); + /// var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); + /// // Embedding lookup. + /// var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64); + /// // Query embeddings of shape [batch_size, Tq, dimension]. + /// var query_embeddings = token_embedding.Apply(query_input); + /// // Value embeddings of shape [batch_size, Tv, dimension]. + /// var value_embeddings = token_embedding.Apply(value_input); + /// // CNN layers: one for the query sequence, one for the value sequence. + /// var cnn_layer = keras.layers.Conv1D( + /// filters: 100, + /// kernel_size: 4, + /// // Use 'same' padding so outputs have the same shape as inputs. + /// padding: "same"); + /// var cnn_layer2 = keras.layers.Conv1D( + /// filters: 100, + /// kernel_size: 4, + /// // Use 'same' padding so outputs have the same shape as inputs. + /// padding: "same"); + /// // Query encoding of shape [batch_size, Tq, filters]. + /// var query_seq_encoding = cnn_layer.Apply(query_embeddings); + /// // Value encoding of shape [batch_size, Tv, filters]. + /// var value_seq_encoding = cnn_layer2.Apply(value_embeddings); + /// // Query-value attention of shape [batch_size, Tq, filters]. + /// var query_value_attention_seq = keras.layers.Attention().Apply( + /// (query_seq_encoding, value_seq_encoding)); + /// // Reduce over the sequence axis to produce encodings of shape + /// // [batch_size, filters]. + /// var query_encoding = keras.layers.GlobalAveragePooling1D().Apply( + /// query_seq_encoding); + /// var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply( + /// query_value_attention_seq); + /// // Concatenate query and document encodings to produce a DNN input layer. 
+ /// var input_layer = keras.layers.Concatenate().Apply( + /// (query_encoding, query_value_attention)); + /// // Add DNN layers, and create Model. + /// // ... + /// + /// + public class Attention : BaseDenseAttention + { + + public IVariableV1 concat_score_weight; + + public IVariableV1 scale; + + AttentionArgs args; + + string score_mode { get => args.score_mode; } + + bool use_scale { get => args.use_scale; } + + public Attention(AttentionArgs args) : base(args) + { + this.args = args; + if (!new List<string> { + "dot", + "concat" + }.Contains(this.score_mode)) + throw new ValueError($"Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]"); + } + + // Creates the `scale` variable when `use_scale` is true and `concat_score_weight` when `score_mode` is "concat"; otherwise each is set to null. + protected override void build(Tensors inputs) { + if (this.use_scale) + this.scale = this.add_weight(name: "scale", + shape: 1, + initializer: tf.ones_initializer, + dtype: this.DType, + trainable: true); + else + this.scale = null; + + if (this.score_mode == "concat") + this.concat_score_weight = this.add_weight(name: "concat_score_weight", + shape: 1, + initializer: tf.ones_initializer, + dtype: this.DType, + trainable: true); + else + this.concat_score_weight = null; + base.build(inputs); + } + + /// <summary> + /// Calculates attention scores from the query and key according to `score_mode` ("dot" or "concat"). + /// </summary> + /// <param name="query">Query tensor of shape `[batch_size, Tq, dim]`.</param> + /// <param name="key">Key tensor of shape `[batch_size, Tv, dim]`.</param> + /// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns> + public override Tensor _calculate_scores(Tensor query, Tensor key) + { + Tensor scores = null; + if (this.score_mode == "dot") + { + //scores = tf.matmul(query, key, transpose_b: true); + //scores = tf.matmul(tf.squeeze(query),tf.squeeze(key), transpose_b: true); + scores = tf.linalg.einsum("bij,bkj->bik", (query, key)); + if (this.scale != null) + scores *= this.scale.AsTensor(); + } else if (this.score_mode == "concat") { + // Reshape tensors to enable broadcasting. 
+ // Reshape into [batch_size, Tq, 1, dim]. + var q_reshaped = tf.expand_dims(query, axis: -2); + // Reshape into [batch_size, 1, Tv, dim]. + var k_reshaped = tf.expand_dims(key, axis: -3); + if (this.scale != null) + scores = this.concat_score_weight.AsTensor() * + tf.reduce_sum(tf.tanh(this.scale.AsTensor() * (q_reshaped + k_reshaped)), axis: -1); + else + scores = this.concat_score_weight.AsTensor() * + tf.reduce_sum(tf.tanh(q_reshaped + k_reshaped), axis: -1); + } + return scores; + } + + public override LayerArgs get_config() => this.args; + //var config = new Dictionary { + // { + // "use_scale", + // this.use_scale}, + // { + // "score_mode", + // this.score_mode}}; + //var base_config = base.get_config(); + //return new dict(base_config.items().ToList() + config.items().ToList()); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs new file mode 100644 index 00000000..190ad5a7 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -0,0 +1,257 @@ +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System; +using System.Collections.Generic; +using System.Linq; + +/// +/// Base class for attention layers that can be used in sequence DNN/CNN models. +///This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2. +///Attention is formed by three tensors: Query, Key and Value. +/// + +namespace Tensorflow.Keras.Layers +{ + + /// + /// Base Attention class for Dense networks. + /// This class is suitable for Dense or CNN networks, and not for RNN networks. + /// Implementations of attention mechanisms should inherit from this class, and + /// reuse the `apply_attention_scores()` method. 
+ /// + public class BaseDenseAttention : Layer + { + + BaseDenseAttentionArgs args; + + bool causal { get => args.causal; } + + float dropout { get => args.dropout; } + + protected bool supports_masking; + + public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args) + { + this.args = args; + this.supports_masking = true; + } + + /// + /// Calculates attention scores. + /// + /// query: Query tensor of shape `[batch_size, Tq, dim]`. + /// key: Key tensor of shape `[batch_size, Tv, dim]`. + /// Tensor of shape `[batch_size, Tq, Tv]`. + public virtual Tensor _calculate_scores(Tensor query, Tensor key) => + throw new NotImplementedException(""); + + /// + /// Applies attention scores to the given value tensor. + /// To use this method in your attention layer, follow the steps: + /// + /// * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape + /// `[batch_size, Tv]` to calculate the attention `scores`. + /// + /// + /// * Pass `scores` and `value` tensors to this method. The method applies + /// `scores_mask`, calculates `attention_distribution = softmax(scores)`, then + /// returns `matmul(attention_distribution, value). + /// + /// + /// * Apply `query_mask` and return the result. + /// + /// + /// Scores float tensor of shape `[batch_size, Tq, Tv]`. + /// Value tensor of shape `[batch_size, Tv, dim]`. + /// + /// A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or + /// [batch_size, Tq, Tv]`. If given, scores at positions where + /// `scores_mask==False` do not contribute to the result. It must contain + /// at least one `True` value in each line along the last dimension. + /// + /// + /// Boolean indicating whether the layer should behave in + /// training mode (adding dropout) or in inference mode (no dropout). + /// + /// + /// + /// Tensor of shape `[batch_size, Tq, dim]`. + /// + /// + /// Attention scores after masking and softmax with shape + /// [batch_size, Tq, Tv]`. 
+ /// + /// + public (Tensor, Tensor) _apply_scores(Tensor scores, + Tensor value, + Tensor scores_mask = null, + bool? training = null) + { + if (scores_mask != null) + { + var padding_mask = tf.logical_not(scores_mask); + // Bias so padding positions do not contribute to attention distribution. + // Note 65504. is the max float16 value. + if (scores.dtype == tf.float16) + scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype); + else + scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype); + } + bool _training; + training ??= false; // TODO: Delete this line when backend.learning_phase is available + if (training == null) + _training = keras.backend.learning_phase() == + Tensorflow.Keras.GraphLearningPhase.train_mode ? + true : false; + else _training = training.Value; + var weights = tf.nn.softmax(scores); + Func<Tensor> dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout); + weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights)); + //return (tf.matmul(weights, value), weights); + return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensors _inp; + Tensors _mask = null; + + int count = inputs.Count(); + if (count < 2 || count > 6) throw new ValueError( + $"{ this.name } layer accepts inputs list of length from 2 to 6, " + $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]. " 
+ + $"Received length: {count}."); + + bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL; + bool return_attention_scores = false; + if (has_bool) + { + return_attention_scores = (bool)inputs[count - 1]; + count--; + } + + switch (count) + { + case 2: + _inp = (inputs[0], inputs[1]); + break; + case 3: + _inp = new[] { inputs[0], inputs[1], inputs[2] }; + break; + case 4: + if (inputs[0].shape == inputs[2].shape) + if (inputs[1].shape == inputs[3].shape) + { + _inp = new[] { inputs[0], inputs[1] }; + _mask = new[] { inputs[2], inputs[3] }; + break; + } + throw new ValueError(); //TODO:Add discriptions for this err + case 5: + _inp = new[] { inputs[0], inputs[1], inputs[2] }; + _mask = (inputs[3], inputs[4]); + break; + default: + throw new ValueError(); //TODO:Add discriptions for this err + } + + return call(_inp, _mask, training, return_attention_scores); + } + + protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false) + { + Tensor causal_mask; + //this._validate_call_args(inputs: inputs, mask: mask); + var q = inputs[0]; + var v = inputs[1]; + var k = inputs.Count() > 2 ? inputs[2] : v; + var q_mask = mask != null ? mask[0] : null; + var v_mask = mask != null ? mask[1] : null; + var scores = this._calculate_scores(query: q, key: k); + if (v_mask != null) + // Mask of shape [batch_size, 1, Tv]. + v_mask = tf.expand_dims(v_mask, axis: -2); + if (this.causal) + { + // Creates a lower triangular mask, so position i cannot attend to + // positions j>i. This prevents the flow of information from the future + // into the past. + var scores_shape = tf.shape(scores); + // causal_mask_shape = [1, Tq, Tv]. 
+ var causal_mask_shape = tf.concat(new List { + tf.ones_like(tf.slice(scores_shape, new[]{0}, new[]{-2})), + tf.concat(new[]{scores_shape[-2], scores_shape[-1]}, 0) + }, axis: 0); + var _causal_mask_shape = new Shape(causal_mask_shape.ToArray()); + causal_mask = _lower_triangular_mask(_causal_mask_shape); + } + else + causal_mask = null; + var scores_mask = _merge_masks(v_mask, causal_mask); + var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training); + if (q_mask != null) + { + // Mask of shape [batch_size, Tq, 1]. + q_mask = tf.expand_dims(q_mask, axis: -1); + result *= tf.cast(q_mask, dtype: result.dtype); + } + if (return_attention_scores) + return new Tensors(result, attention_scores); + return result; + } + + public Tensor compute_mask(Tensors inputs, Tensors mask = null) + { + this._validate_call_args(inputs: inputs, mask: mask); + if (mask != null) + { + var q_mask = mask[0]; + if (q_mask == null) + return null; + return tf.convert_to_tensor(q_mask); + } + return null; + } + + //public Shape compute_output_shape(Shape input_shape) { + // // return_attention_scores argument of BaseDenseAttention.call method + // // is ignored. Output shape of attention_scores cannot be returned. + // return input_shape[0]; + //} + + /// + /// Validates arguments of the call method. + /// + public void _validate_call_args(Tensors inputs, Tensors mask) + { + if (inputs.Count() < 2 || inputs.Count() > 3) + throw new ValueError( + $"{this.name} layer accepts inputs list of length 2 or 3, " + + $"namely [query, value] or [query, value, key]. Received length: {len(inputs)}."); + if (mask != null) + if (mask.Count() < 2 || mask.Count() > inputs.Count()) + throw new ValueError($"{this.name} layer mask must be a list of length 2, " + + $"namely [query_mask, value_mask]. 
Received length: {len(mask)}."); + } + + public static Tensor _lower_triangular_mask(Shape shape) + { + var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2); + var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1); + return tf.greater_equal(row_index, col_index); + } + + public static Tensor _merge_masks(Tensor x, Tensor y) + { + if (x == null) + return y; + if (y == null) + return x; + return tf.logical_and(x, y); + } + + public override LayerArgs get_config() => this.args; + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs new file mode 100644 index 00000000..4175b458 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs @@ -0,0 +1,25 @@ +using System; +using Tensorflow.NumPy; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi + { + public Attention Attention(bool use_scale = false, + string score_mode = "dot", + bool causal = false, + float dropout = 0f) => + new Attention(new AttentionArgs + { + use_scale = use_scale, + score_mode = score_mode, + causal = causal, + dropout = dropout + }); + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs new file mode 100644 index 00000000..0807b87c --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs @@ -0,0 +1,310 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow.Keras.Layers; +using Tensorflow; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; + +namespace 
TensorFlowNET.Keras.UnitTest +{ + [TestClass] + public class AttentionTest : EagerModeTestBase + { + #region BaseDenseAttention + [TestMethod] + public void test_one_dim_with_mask() + { + // Scores tensor of shape [1, 1, 1] + var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32); + // Value tensor of shape [1, 1, 1] + var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); + // Scores mask tensor of shape [1, 1, 1] + var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax_scores = [[[1]]] + var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [1, 1, 1]. + // expected000 = softmax_scores[0, 0] * 1.6 = 1.6 + var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_one_dim_no_mask() + { + // Scores tensor of shape [1, 1, 1] + var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32); + // Value tensor of shape [1, 1, 1] + var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax_scores = [[[1]]] + var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [1, 1, 1]. 
+ // expected000 = softmax_scores[0, 0] * 1.6 = 1.6 + var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_multi_dim_with_mask() + { + // Scores tensor of shape [1, 1, 3] + var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32); + // Value tensor of shape [1, 3, 1] + var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32); + // Scores mask tensor of shape [1, 1, 3] + var scores_mask = np.array(new[, ,] { { { true, true, false } } }, dtype: np.@bool); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax scores = softmax(scores) with zeros in positions where + // v_mask == False. + // => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863 + // softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137 + // softmax_scores002 = 0 + var expected_scores = np.array(new[, ,] { { { 0.73105857863f, 0.26894142137f, 0f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [1, 1, 1]. 
+ // expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8 + // = 1.35795272077 + //Actually the output is 1.3579528 + var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_multi_dim_no_mask() + { + // Scores tensor of shape [1, 1, 3] + var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32); + // Value tensor of shape [1, 3, 1] + var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax_scores = softmax(scores). + // => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1)) + // = 0.42231879825 + // softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1)) + // = 0.15536240349 + // softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1)) + // = 0.42231879825 + //Actually the output is 0.42231882, 0.15536241, 0.42231882 + var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [1, 1, 1]. 
+ // expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7 + // - 0.42231879825 * 0.8 + // = 0.44660872104 + // Actually the output is 0.44660875 + var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_one_dim_batch_size_two() + { + // Scores tensor of shape [2, 1, 1] + var scores = np.array(new[, ,] { { { 1.1f } }, { { 2.1f } } }, dtype: np.float32); + // Value tensor of shape [2, 1, 1] + var v = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32); + // Scores mask tensor of shape [2, 1, 1] + var scores_mask = np.array(new[, ,] { { { true } }, { { true } } }, dtype: np.@bool); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax_scores = [[[1]], [[1]]] + var expected_scores = np.array(new[, ,] { { { 1f } }, { { 1f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [2, 1, 1]. + // expected000 = softmax_scores[0, 0] * 1.6 = 1.6 + // expected100 = softmax_scores[1, 0] * 2.6 = 2.6 + var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_shape_with_dropout() + { + // scores: Scores float tensor of shape `[batch_size, tq, tv]`. + // value: Value tensor of shape `[batch_size, tv, dim]`. + var batch_size = 4; + var tq = 5; + var tv = 6; + var dim = 7; + var scores = np.ones((batch_size, tq, tv)); + var value = np.ones((batch_size, tv, dim)); + var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f }) + ._apply_scores(scores: scores, value: value, training: false); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected Tensor of shape `[batch_size, tq, tv]`. 
+ var expected_scores_shape = new[] { + batch_size, + tq, + tv + }; + Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy()); + // Expected Tensor of shape `[batch_size, tq, dim]`. + var expected_shape = new[] { + batch_size, + tq, + dim + }; + Assert.AreEqual(expected_shape, tf.shape(actual).numpy()); + } + #endregion + // ------------------------------------------------------------------ + #region Attention + [TestMethod] + public void test_example() + { + //Variable-length int sequences. + var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); + var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); + // Embedding lookup. + var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64); + // Query embeddings of shape [batch_size, Tq, dimension]. + var query_embeddings = token_embedding.Apply(query_input); + // Value embeddings of shape [batch_size, Tv, dimension]. + var value_embeddings = token_embedding.Apply(value_input); + // CNN layer. + var cnn_layer = keras.layers.Conv1D( + filters: 100, + kernel_size: 4, + // Use 'same' padding so outputs have the same shape as inputs. + padding: "same", + activation: "relu"); + var cnn_layer2 = keras.layers.Conv1D( + filters: 100, + kernel_size: 4, + // Use 'same' padding so outputs have the same shape as inputs. + padding: "same", + activation: "relu"); + // Query encoding of shape [batch_size, Tq, filters]. + var query_seq_encoding = cnn_layer.Apply(query_embeddings); + // Value encoding of shape [batch_size, Tv, filters]. + var value_seq_encoding = cnn_layer2.Apply(value_embeddings); + // Query-value attention of shape [batch_size, Tq, filters]. + var query_value_attention_seq = keras.layers.Attention().Apply( + (query_seq_encoding, value_seq_encoding)); + // Reduce over the sequence axis to produce encodings of shape + // [batch_size, filters]. 
+ var query_encoding = keras.layers.GlobalAveragePooling1D().Apply( + query_seq_encoding); + var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply( + query_value_attention_seq); + // Concatenate query and document encodings to produce a DNN input layer. + var input_layer = keras.layers.Concatenate().Apply( + (query_encoding, query_value_attention)); + // Add DNN layers, and create Model. + // ... + } + + [TestMethod] + public void test_calculate_scores_one_dim() + { + // Query tensor of shape [1, 1, 1] + var q = np.array(new[,,] { { { 1.1f } } }, dtype: np.float32); + // Key tensor of shape [1, 1, 1] + var k = np.array(new[,,] { { { 1.6f } } }, dtype: np.float32); + var attention_layer = keras.layers.Attention(); + //attention_layer.build((1)); + var actual = attention_layer._calculate_scores(query: q, key: k); + // Expected tensor of shape [1, 1, 1]. + // expected000 = 1.1*1.6 = 1.76 + // Actually the output is 1.7600001 + var expected = np.array(new[,,] { { { 1.7600001f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_calculate_scores_multi_dim() + { + // Query tensor of shape [1, 2, 4] + var q = np.array(new[, ,] { { + { 1f, 1.1f, 1.2f, 1.3f }, + { 2f, 2.1f, 2.2f, 2.3f } + } }, dtype: np.float32); + // Key tensor of shape [1, 3, 4] + var k = np.array(new[, ,] { { + { 1.5f, 1.6f, 1.7f, 1.8f }, + { 2.5f, 2.6f, 2.7f, 2.8f }, + { 3.5f, 3.6f, 3.7f, 3.8f } + } }, dtype: np.float32); + var attention_layer = keras.layers.Attention(); + //attention_layer.build(((1, 2, 4), (1, 3, 4))); + var actual = attention_layer._calculate_scores(query: q, key: k); + // Expected tensor of shape [1, 2, 3]. 
+ // expected000 = 1.*1.5+1.1*1.6+1.2*1.7+1.3*1.8 = 7.64 + // expected001 = 1.*2.5+1.1*2.6+1.2*2.7+1.3*2.8 = 12.24 + // expected002 = 1.*3.5+1.1*3.6+1.2*3.7+1.3*3.8 = 16.84 + // expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24 + // expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84 + // expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44 + // Actually the output000 is 7.6400003, the output012 is 31.439999 + var expected = np.array(new[, ,] { { + { 7.6400003f, 12.24f, 16.84f }, + { 14.24f, 22.84f, 31.439999f } + } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_calculate_scores_multi_dim_concat() + { + // Query tensor of shape [1, 2, 4] + var q = np.array(new[, ,] { { + { 1f, 1.1f, 1.2f, 1.3f }, + { 2f, 2.1f, 2.2f, 2.3f } + } }, dtype: np.float32); + // Key tensor of shape [1, 3, 4] + var k = np.array(new[, ,] { { + { 1.5f, 1.6f, 1.7f, 1.8f }, + { 2.5f, 2.6f, 2.7f, 2.8f }, + { 3.5f, 3.6f, 3.7f, 3.8f } + } }, dtype: np.float32); + var attention_layer = keras.layers.Attention(score_mode: "concat"); + //attention_layer.concat_score_weight = 1; + attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { + Name = "concat_score_weight", + Shape = (1), + DType = TF_DataType.TF_FLOAT, + Getter = base_layer_utils.make_variable, + Overwrite = true, + Initializer = tf.ones_initializer, + Synchronization = VariableSynchronization.Auto, + Aggregation = VariableAggregation.None, + Trainable = true + }); + //attention_layer.build(((1, 2, 4), (1, 3, 4))); + //var actual = keras.backend.get_value(attention_layer._calculate_scores(query: q, key: k)); + var actual = attention_layer._calculate_scores(query: q, key: k); + // pylint:disable=line-too-long + // expected000 = tanh(1.+1.5) + tanh(1.1+1.6) + tanh(1.2+1.7) + tanh(1.3+1.8) = 3.96753427840 + // expected001 = tanh(1.+2.5) + tanh(1.1+2.6) + tanh(1.2+2.7) + tanh(1.3+2.8) = 3.99558784825 + // expected002 = tanh(1.+3.5) + tanh(1.1+3.6) 
+ tanh(1.2+3.7) + tanh(1.3+3.8) = 3.99940254147 + // expected010 = tanh(2.+1.5) + tanh(2.1+1.6) + tanh(2.2+1.7) + tanh(2.3+1.8) = 3.99558784825 + // expected011 = tanh(2.+2.5) + tanh(2.1+2.6) + tanh(2.2+2.7) + tanh(2.3+2.8) = 3.99940254147 + // expected012 = tanh(2.+3.5) + tanh(2.1+3.6) + tanh(2.2+3.7) + tanh(2.3+3.8) = 3.99991913657 + //Actually the output012 is 3.9999194 + var expected = np.array(new[, ,] { { + { 3.96753427840f, 3.99558784825f, 3.99940254147f }, + { 3.99558784825f, 3.99940254147f, 3.9999194f } + } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + #endregion + } + +} \ No newline at end of file