diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs
index e2d3ad8b..3ffae27f 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs
@@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers {
             axis = args.axis;
         }
         protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) {
-            Tensor x = inputs;
+            Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * -1e9)
+                : inputs;
             Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));
             Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true);
             return tf.div(e, s);
diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
index 190ad5a7..3f618b5d 100644
--- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
+++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
@@ -120,7 +120,7 @@ namespace Tensorflow.Keras.Layers
             int count = inputs.Count();
             if (count < 2 || count > 6) throw new ValueError(
-                $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
+                $"{ this.name } layer accepts inputs list of length from 2 to 6, " +
                 $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
                 $"Received length: {count}.");
diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
new file mode 100644
index 00000000..1a936aa7
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
@@ -0,0 +1,352 @@
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using System;
+using System.Linq;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class MultiHeadAttention : Layer
+    {
+        static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";
+
+        MultiHeadAttentionArgs args;
+        Shape _query_shape = null;
+        Shape _key_shape = null;
+        Shape _value_shape = null;
+        bool _built_from_signature = false;
+        EinsumDense _query_dense = null;
+        EinsumDense _key_dense = null;
+        EinsumDense _value_dense = null;
+        EinsumDense _output_dense = null;
+        string _dot_product_equation = "";
+        string _combine_equation = "";
+        Softmax _softmax = null;
+        Dropout _dropout_layer = null;
+
+        /// <summary>
+        /// Builds einsum equations for the attention computation.
+        /// Query, key, value inputs after projection are expected to have the shape as:
+        /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
+        /// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
+        /// <para>
+        /// The attention operations can be generalized:
+        /// </para>
+        /// <para>
+        /// (1) Query-key dot product:
+        /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
+        /// [key attention dims], num_heads, channels) -> ([batch dims],
+        /// num_heads, [query attention dims], [key attention dims])`
+        ///
+        /// (2) Combination:
+        /// `([batch dims], num_heads, [query attention dims], [key attention dims]),
+        /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
+        /// [query attention dims], num_heads, channels)`
+        /// </para>
+        /// </summary>
+        /// <param name="rank">Rank of query, key, value tensors.</param>
+        /// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
+        /// that attention will be applied to.</param>
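+        /// <remarks>
+        /// For example (illustrative values, not from the source): rank = 4 with attn_axes = (1)
+        /// yields dot_product_equation "aecd,abcd->acbe" and combine_equation "acbe,aecd->abcd",
+        /// i.e. attention scores of shape (bs, num_heads, query dims, key dims).
+        /// </remarks>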
+        /// <returns></returns>
+        public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
+        {
+            var target_notation = _CHR_IDX.Substring(0, rank);
+            // `batch_dims` includes the head dim.
+            var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
+            var letter_offset = rank;
+            var source_notation = "";
+            for (int i = 0; i < rank; i++)
+            {
+                if (batch_dims.Contains(i) || i == rank - 1)
+                    source_notation += target_notation[i];
+                else
+                {
+                    source_notation += _CHR_IDX[letter_offset];
+                    letter_offset += 1;
+                }
+            }
+            var product_notation = "".Insert(0, new string((from i in batch_dims
+                                                            select (char)(int)target_notation[i]).Concat(
+
+                                                            from i in attn_axes.as_int_list()
+                                                            select (char)(int)target_notation[i]).Concat(
+
+                                                            from i in attn_axes.as_int_list()
+                                                            select source_notation[i]).ToArray()));
+            var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
+            var attn_scores_rank = product_notation.Count();
+            var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
+            return (dot_product_equation, combine_equation, attn_scores_rank);
+        }
+
+        /// <summary>
+        /// Builds an einsum equation for projections inside multi-head attention.
+        /// </summary>
+        public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
+        {
+            char _char;
+            var input_str = "";
+            var kernel_str = "";
+            var output_str = "";
+            var bias_axes = "";
+            var letter_offset = 0;
+            foreach (var i in range(free_dims))
+            {
+                _char = _CHR_IDX[i + letter_offset];
+                input_str += _char;
+                output_str += _char;
+            }
+            letter_offset += free_dims;
+            foreach (var i in range(bound_dims))
+            {
+                _char = _CHR_IDX[i + letter_offset];
+                input_str += _char;
+                kernel_str += _char;
+            }
+            letter_offset += bound_dims;
+            foreach (var i in range(output_dims))
+            {
+                _char = _CHR_IDX[i + letter_offset];
+                kernel_str += _char;
+                output_str += _char;
+                bias_axes += _char;
+            }
+            var equation = $"{input_str},{kernel_str}->{output_str}";
+            return (equation, bias_axes, output_str.Count());
+        }
+
+        static Shape _get_output_shape(int output_rank, Shape known_last_dims)
+            => (from _ in range(output_rank - known_last_dims.rank)
+                select -1).Concat(known_last_dims.as_int_list()).ToArray();
+
+        public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
+        {
+            this.args = args;
+        }
+
+        public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
+            => this._build_from_signature(query.shape, value.shape, key?.shape);
+
+        public void _build_from_signature(Shape query, Shape value, Shape key = null)
+        {
+            this._built_from_signature = true;
+            this._query_shape = query;
+            this._value_shape = value;
+            if (key == null)
+                this._key_shape = this._value_shape;
+            else
+                this._key_shape = key;
+            // Any setup work performed only once should happen in an `init_scope`
+            // to avoid creating symbolic Tensors that will later pollute any eager
+            // operations.
+            tf_with(tf.init_scope(), _ =>
+            {
+                var free_dims = this._query_shape.rank - 1;
+                var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
+                    free_dims, bound_dims: 1, output_dims: 2);
+                this._query_dense = _get_dense(einsum_equation,
+                    _get_output_shape(output_rank - 1, (this.args.NumHeads, this.args.KeyDim)),
+                    this.args.UseBias ?
+                        bias_axes : null,
+                    "query");
+                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
+                    this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
+                this._key_dense = _get_dense(einsum_equation,
+                    _get_output_shape(output_rank - 1, (this.args.NumHeads, this.args.KeyDim)),
+                    this.args.UseBias ? bias_axes : null,
+                    "key");
+                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
+                    this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
+                this._value_dense = _get_dense(einsum_equation,
+                    _get_output_shape(output_rank - 1, (this.args.NumHeads, this.args.ValueDim ?? -1)),
+                    this.args.UseBias ? bias_axes : null,
+                    "value");
+                // Builds the attention computations for multi-head dot product attention.
+                // These computations could be wrapped into the keras attention layer once
+                // it supports multi-head einsum computations.
+                this._build_attention(output_rank);
+                this._output_dense = _build_output_dense(free_dims, "attention_output");
+            });
+            this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
+        }
+
+        EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
+            => new EinsumDense(new EinsumDenseArgs()
+            {
+                Equation = equation,
+                OutputShape = output_shape,
+                BiasAxes = bias_axes,
+                Name = name,
+                KernelInitializer = this.args.KernelInitializer,
+                BiasInitializer = this.args.BiasInitializer,
+                KernelRegularizer = this.args.KernelRegularizer,
+                BiasRegularizer = this.args.BiasRegularizer,
+                KernelConstraint = this.args.KernelConstraint,
+                BiasConstraint = this.args.BiasConstraint
+            });
+
+        EinsumDense _build_output_dense(int free_dims, string name)
+        {
+            if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]);
+            var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
+                free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
+            return _get_dense(einsum_equation,
+                _get_output_shape(output_rank - 1, this.args.OutputShape),
+                this.args.UseBias ? bias_axes : null,
+                name);
+        }
+
+        void _build_attention(int rank)
+        {
+            if (this.args.AttentionAxis == null)
+                this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
+            int attn_scores_rank;
+            (this._dot_product_equation, this._combine_equation, attn_scores_rank)
+                = _build_attention_equation(rank, this.args.AttentionAxis);
+            var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
+                                  attn_scores_rank).ToArray();
+            this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
+            this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
+        }
+
+        Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
+        {
+            if (attention_mask != null)
+            {
+                var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
+                for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++)
+                    attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
+            }
+            return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
+        }
+
+        public Tensors _compute_attention(
+            Tensor query,
+            Tensor key,
+            Tensor value,
+            Tensor attention_mask = null,
+            bool training = false)
+        {
+            // Note: Applying scalar multiply at the smaller end of einsum improves
+            // XLA performance, but may introduce slight numeric differences in
+            // the Transformer attention head.
+            query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
+            // Take the dot product between "query" and "key" to get the raw
+            // attention scores.
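+            // With the default equations for rank-4 projected inputs (shapes shown for illustration only),
+            // this contracts the channel axis: key [B, S, N, H] x query [B, T, N, H] -> scores [B, N, T, S].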
+            var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
+            attention_scores = this._masked_softmax(attention_scores, attention_mask);
+            // This is actually dropping out entire tokens to attend to, which might
+            // seem a bit unusual, but is taken from the original Transformer paper.
+            var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
+            // `context_layer` = [B, T, N, H]
+            var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
+            return (attention_output, attention_scores);
+        }
+
+        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+        {
+            Tensors _inp;
+            Tensor _mask = null;
+
+            int count = inputs.Count();
+            if (count < 2 || count > 5) throw new ValueError(
+                $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
+                $"namely [query, value, (key), (attention_mask), (return_attention_scores)]." +
+                $"Received length: {count}.");
+
+            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
+            bool return_attention_scores = false;
+            if (has_bool)
+            {
+                return_attention_scores = (bool)inputs[count - 1];
+                count--;
+            }
+
+            switch (count)
+            {
+                case 2:
+                    _inp = (inputs[0], inputs[1]);
+                    break;
+                case 3:
+                    if (inputs[2].shape[-1] != inputs[0].shape[-1])
+                        _inp = new[] { inputs[0], inputs[1], inputs[2] };
+                    else
+                    {
+                        _inp = (inputs[0], inputs[1]);
+                        _mask = inputs[2];
+                    }
+                    break;
+                case 4:
+                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
+                    _mask = inputs[3];
+                    break;
+                default:
+                    throw new ValueError(); // TODO: Add a description for this error.
+            }
+
+            return call(_inp, _mask, training, return_attention_scores);
+        }
+
+        protected Tensors call(Tensors inputs,
+                               Tensor attention_mask,
+                               bool? training = null,
+                               bool return_attention_scores = false)
+        {
+            var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
+            if (!this._built_from_signature)
+                this._build_from_signature(query: query, value: value, key: key);
+            if (key == null)
+                key = value;
+
+            // TODO: Add RaggedTensor support
+            //var query_is_ragged = query is tf.RaggedTensor;
+            //if (query_is_ragged)
+            //{
+            //    var query_lengths = query.nested_row_lengths();
+            //    query = query.to_tensor();
+            //}
+            //var key_is_ragged = key is tf.RaggedTensor;
+            //var value_is_ragged = value is tf.RaggedTensor;
+            //if (key_is_ragged && value_is_ragged)
+            //{
+            //    // Ensure they have the same shape.
+            //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
+            //    key = key.to_tensor(shape: bounding_shape);
+            //    value = value.to_tensor(shape: bounding_shape);
+            //}
+            //else if (key_is_ragged)
+            //{
+            //    key = key.to_tensor(shape: tf.shape(value));
+            //}
+            //else if (value_is_ragged)
+            //{
+            //    value = value.to_tensor(shape: tf.shape(key));
+            //}
+
+            // N = `num_attention_heads`
+            // H = `size_per_head`
+            // `query` = [B, T, N, H]
+            query = this._query_dense.Apply(query);
+            // `key` = [B, S, N, H]
+            key = this._key_dense.Apply(key);
+            // `value` = [B, S, N, H]
+            value = this._value_dense.Apply(value);
+            var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ??
+                false);
+            attention_output = this._output_dense.Apply(attention_output);
+
+            //if (query_is_ragged)
+            //{
+            //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
+            //}
+
+            if (return_attention_scores)
+                return (attention_output, attention_scores);
+            return attention_output;
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
index 4175b458..5effd175 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
@@ -21,5 +21,36 @@ namespace Tensorflow.Keras.Layers
                 causal = causal,
                 dropout = dropout
             });
+        public MultiHeadAttention MultiHeadAttention(int num_heads,
+            int key_dim,
+            int? value_dim = null,
+            float dropout = 0f,
+            bool use_bias = true,
+            Shape output_shape = null,
+            Shape attention_axes = null,
+            IInitializer kernel_initializer = null,
+            IInitializer bias_initializer = null,
+            IRegularizer kernel_regularizer = null,
+            IRegularizer bias_regularizer = null,
+            IRegularizer activity_regularizer = null,
+            Action kernel_constraint = null,
+            Action bias_constraint = null) =>
+                new MultiHeadAttention(new MultiHeadAttentionArgs
+                {
+                    NumHeads = num_heads,
+                    KeyDim = key_dim,
+                    ValueDim = value_dim,
+                    Dropout = dropout,
+                    UseBias = use_bias,
+                    OutputShape = output_shape,
+                    AttentionAxis = attention_axes,
+                    KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
+                    BiasInitializer = bias_initializer ?? tf.zeros_initializer,
+                    KernelRegularizer = kernel_regularizer,
+                    BiasRegularizer = bias_regularizer,
+                    ActivityRegularizer = activity_regularizer,
+                    KernelConstraint = kernel_constraint,
+                    BiasConstraint = bias_constraint,
+                });
     }
 }
\ No newline at end of file
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
index 0807b87c..54ac3795 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
@@ -15,45 +15,6 @@ namespace TensorFlowNET.Keras.UnitTest
     public class AttentionTest : EagerModeTestBase
     {
         #region BaseDenseAttention
-        [TestMethod]
-        public void test_one_dim_with_mask()
-        {
-            // Scores tensor of shape [1, 1, 1]
-            var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-            // Value tensor of shape [1, 1, 1]
-            var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-            // Scores mask tensor of shape [1, 1, 1]
-            var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool);
-            var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
-            var actual = _tup_1.Item1;
-            var actual_scores = _tup_1.Item2;
-            // Expected softmax_scores = [[[1]]]
-            var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
-            Assert.AreEqual(expected_scores, actual_scores.numpy());
-            // Expected tensor of shape [1, 1, 1].
-            // expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-            var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-            Assert.AreEqual(expected, actual.numpy());
-        }
-
-        [TestMethod]
-        public void test_one_dim_no_mask()
-        {
-            // Scores tensor of shape [1, 1, 1]
-            var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-            // Value tensor of shape [1, 1, 1]
-            var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-            var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
-            var actual = _tup_1.Item1;
-            var actual_scores = _tup_1.Item2;
-            // Expected softmax_scores = [[[1]]]
-            var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
-            Assert.AreEqual(expected_scores, actual_scores.numpy());
-            // Expected tensor of shape [1, 1, 1].
-            // expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-            var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-            Assert.AreEqual(expected, actual.numpy());
-        }
         [TestMethod]
         public void test_multi_dim_with_mask()
@@ -81,35 +42,6 @@ namespace TensorFlowNET.Keras.UnitTest
             var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32);
             Assert.AreEqual(expected, actual.numpy());
         }
-
-        [TestMethod]
-        public void test_multi_dim_no_mask()
-        {
-            // Scores tensor of shape [1, 1, 3]
-            var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32);
-            // Value tensor of shape [1, 3, 1]
-            var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32);
-            var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
-            var actual = _tup_1.Item1;
-            var actual_scores = _tup_1.Item2;
-            // Expected softmax_scores = softmax(scores).
-            // => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
-            //                      = 0.42231879825
-            //    softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
-            //                      = 0.15536240349
-            //    softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
-            //                      = 0.42231879825
-            // Actually the output is 0.42231882, 0.15536241, 0.42231882
-            var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32);
-            Assert.AreEqual(expected_scores, actual_scores.numpy());
-            // Expected tensor of shape [1, 1, 1].
-            // expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
-            //             - 0.42231879825 * 0.8
-            //             = 0.44660872104
-            // Actually the output is 0.44660875
-            var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32);
-            Assert.AreEqual(expected, actual.numpy());
-        }
         [TestMethod]
         public void test_one_dim_batch_size_two()
@@ -132,101 +64,10 @@ namespace TensorFlowNET.Keras.UnitTest
             var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32);
             Assert.AreEqual(expected, actual.numpy());
         }
-
-        [TestMethod]
-        public void test_shape_with_dropout()
-        {
-            // scores: Scores float tensor of shape `[batch_size, tq, tv]`.
-            // value: Value tensor of shape `[batch_size, tv, dim]`.
-            var batch_size = 4;
-            var tq = 5;
-            var tv = 6;
-            var dim = 7;
-            var scores = np.ones((batch_size, tq, tv));
-            var value = np.ones((batch_size, tv, dim));
-            var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f })
-                ._apply_scores(scores: scores, value: value, training: false);
-            var actual = _tup_1.Item1;
-            var actual_scores = _tup_1.Item2;
-            // Expected Tensor of shape `[batch_size, tq, tv]`.
-            var expected_scores_shape = new[] {
-                batch_size,
-                tq,
-                tv
-            };
-            Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy());
-            // Expected Tensor of shape `[batch_size, tq, dim]`.
-            var expected_shape = new[] {
-                batch_size,
-                tq,
-                dim
-            };
-            Assert.AreEqual(expected_shape, tf.shape(actual).numpy());
-        }
         #endregion
         // ------------------------------------------------------------------
         #region Attention
-        [TestMethod]
-        public void test_example()
-        {
-            // Variable-length int sequences.
-            var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
-            var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
-            // Embedding lookup.
-            var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
-            // Query embeddings of shape [batch_size, Tq, dimension].
-            var query_embeddings = token_embedding.Apply(query_input);
-            // Value embeddings of shape [batch_size, Tv, dimension].
-            var value_embeddings = token_embedding.Apply(value_input);
-            // CNN layer.
-            var cnn_layer = keras.layers.Conv1D(
-                filters: 100,
-                kernel_size: 4,
-                // Use 'same' padding so outputs have the same shape as inputs.
-                padding: "same",
-                activation: "relu");
-            var cnn_layer2 = keras.layers.Conv1D(
-                filters: 100,
-                kernel_size: 4,
-                // Use 'same' padding so outputs have the same shape as inputs.
-                padding: "same",
-                activation: "relu");
-            // Query encoding of shape [batch_size, Tq, filters].
-            var query_seq_encoding = cnn_layer.Apply(query_embeddings);
-            // Value encoding of shape [batch_size, Tv, filters].
-            var value_seq_encoding = cnn_layer2.Apply(value_embeddings);
-            // Query-value attention of shape [batch_size, Tq, filters].
-            var query_value_attention_seq = keras.layers.Attention().Apply(
-                (query_seq_encoding, value_seq_encoding));
-            // Reduce over the sequence axis to produce encodings of shape
-            // [batch_size, filters].
-            var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
-                query_seq_encoding);
-            var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
-                query_value_attention_seq);
-            // Concatenate query and document encodings to produce a DNN input layer.
-            var input_layer = keras.layers.Concatenate().Apply(
-                (query_encoding, query_value_attention));
-            // Add DNN layers, and create Model.
-            // ...
-        }
-        [TestMethod]
-        public void test_calculate_scores_one_dim()
-        {
-            // Query tensor of shape [1, 1, 1]
-            var q = np.array(new[,,] { { { 1.1f } } }, dtype: np.float32);
-            // Key tensor of shape [1, 1, 1]
-            var k = np.array(new[,,] { { { 1.6f } } }, dtype: np.float32);
-            var attention_layer = keras.layers.Attention();
-            //attention_layer.build((1));
-            var actual = attention_layer._calculate_scores(query: q, key: k);
-            // Expected tensor of shape [1, 1, 1].
-            // expected000 = 1.1*1.6 = 1.76
-            // Actually the output is 1.7600001
-            var expected = np.array(new[,,] { { { 1.7600001f } } }, dtype: np.float32);
-            Assert.AreEqual(expected, actual.numpy());
-        }
         [TestMethod]
         public void test_calculate_scores_multi_dim()
@@ -305,6 +146,29 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(expected, actual.numpy());
         }
         #endregion
+        // ------------------------------------------------------------------
+        #region MultiHeadAttention
+        [TestMethod]
+        public void test_masked_attention()
+        {
+            var query = keras.Input(shape: (4, 8));
+            var value = keras.Input(shape: (2, 8));
+            var mask_tensor = keras.Input(shape: (4, 2));
+            var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
+            attention_layer.Apply(new[] { query, value, mask_tensor });
+
+            var from_data = 10 * np.random.randn(3, 4, 8);
+            var to_data = 10 * np.random.randn(3, 2, 8);
+
+            var mask_data = np.random.randint(2, size: (3, 4, 2));
+            var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });
+
+            var null_mask_data = np.ones((3, 4, 2));
+            var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });
+
+            Assert.AreNotEqual(masked_output_data, unmasked_output_data);
+        }
+        #endregion
     }
 }
\ No newline at end of file
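For context, the LayersApi entry point added above can be exercised roughly as follows. This is an illustrative sketch only (layer sizes and variable names are made up, and it assumes the same using directives as the unit test), not part of the diff:

    // Self-attention over a sequence of 16 steps with 64 features each (sizes chosen for illustration).
    var inputs = keras.Input(shape: (16, 64));
    var mha = keras.layers.MultiHeadAttention(num_heads: 4, key_dim: 16);
    // A two-element input list means [query, value]; inside Call, key defaults to value.
    var outputs = mha.Apply(new[] { inputs, inputs });

In the two-tensor form, `Call` maps the inputs to `(query, value)` and reuses `value` as `key`; passing a third tensor whose last dimension differs from the query's is treated as an attention mask, as in `test_masked_attention` above.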