@@ -0,0 +1,20 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class AttentionArgs : BaseDenseAttentionArgs | |||
{ | |||
/// <summary> | |||
/// If `true`, will create a scalar variable to scale the attention scores. | |||
/// </summary> | |||
public bool use_scale { get; set; } = false; | |||
/// <summary> | |||
/// Function to use to compute attention scores, one of | |||
/// `{"dot", "concat"}`. `"dot"` refers to the dot product between the query | |||
/// and key vectors. `"concat"` refers to the hyperbolic tangent of the | |||
/// concatenation of the query and key vectors. | |||
/// </summary> | |||
public string score_mode { get; set; } = "dot"; | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class BaseDenseAttentionArgs : LayerArgs | |||
{ | |||
/// <summary> | |||
/// Boolean. Set to `true` for decoder self-attention. Adds a mask such | |||
/// that position `i` cannot attend to positions `j > i`. This prevents the | |||
/// flow of information from the future towards the past. | |||
/// </summary> | |||
public bool causal { get; set; } = false; | |||
/// <summary> | |||
/// Float between 0 and 1. Fraction of the units to drop for the | |||
/// attention scores. | |||
/// </summary> | |||
public float dropout { get; set; } = 0f; | |||
} | |||
} |
@@ -275,6 +275,15 @@ namespace Tensorflow.Keras.Engine | |||
weights.AddRange(non_trainable_weights); | |||
return weights; | |||
} | |||
set | |||
{ | |||
if (weights.Count() != value.Count()) throw new ValueError( | |||
$"You called `set_weights` on layer \"{this.name}\"" + | |||
$"with a weight list of length {len(value)}, but the layer was " + | |||
$"expecting {len(weights)} weights."); | |||
foreach (var (this_w, v_w) in zip(weights, value)) | |||
this_w.assign(v_w, read_value: true); | |||
} | |||
} | |||
public virtual LayerArgs get_config() | |||
@@ -0,0 +1,159 @@ | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary> | |||
/// Dot-product attention layer, a.k.a. Luong-style attention. | |||
/// Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of | |||
/// shape `[batch_size, Tv, dim]` and `key` tensor of shape | |||
/// `[batch_size, Tv, dim]`. The calculation follows the steps: | |||
/// <para> | |||
/// 1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot | |||
/// product: `scores = tf.matmul(query, key, transpose_b=True)`. | |||
/// </para> | |||
/// <para> | |||
/// 2. Use scores to calculate a distribution with shape | |||
/// `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`. | |||
/// </para> | |||
/// <para> | |||
/// 3. Use `distribution` to create a linear combination of `value` with | |||
/// shape `[batch_size, Tq, dim]`: | |||
/// `return tf.matmul(distribution, value)`. | |||
/// </para> | |||
/// </summary> | |||
/// <example> 0 | |||
/// <code> | |||
/// //Variable-length int sequences. | |||
/// var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); | |||
/// var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); | |||
/// // Embedding lookup. | |||
/// var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64); | |||
/// // Query embeddings of shape [batch_size, Tq, dimension]. | |||
/// var query_embeddings = token_embedding.Apply(query_input); | |||
/// // Value embeddings of shape [batch_size, Tv, dimension]. | |||
/// var value_embeddings = token_embedding.Apply(value_input); | |||
/// // CNN layer. | |||
/// var cnn_layer = keras.layers.Conv1D( | |||
/// filters: 100, | |||
/// kernel_size: 4, | |||
/// // Use 'same' padding so outputs have the same shape as inputs. | |||
/// padding: "same"); | |||
/// var cnn_layer2 = keras.layers.Conv1D( | |||
/// filters: 100, | |||
/// kernel_size: 4, | |||
/// // Use 'same' padding so outputs have the same shape as inputs. | |||
/// padding: "same"); | |||
/// // Query encoding of shape [batch_size, Tq, filters]. | |||
/// var query_seq_encoding = cnn_layer.Apply(query_embeddings); | |||
/// // Value encoding of shape [batch_size, Tv, filters]. | |||
/// var value_seq_encoding = cnn_layer.Apply(value_embeddings); | |||
/// // Query-value attention of shape [batch_size, Tq, filters]. | |||
/// var query_value_attention_seq = keras.layers.Attention().Apply( | |||
/// (query_seq_encoding, value_seq_encoding)); | |||
/// // Reduce over the sequence axis to produce encodings of shape | |||
/// // [batch_size, filters]. | |||
/// var query_encoding = keras.layers.GlobalAveragePooling1D().Apply( | |||
/// query_seq_encoding); | |||
/// var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply( | |||
/// query_value_attention_seq); | |||
/// // Concatenate query and document encodings to produce a DNN input layer. | |||
/// var input_layer = keras.layers.Concatenate().Apply( | |||
/// (query_encoding, query_value_attention)); | |||
/// // Add DNN layers, and create Model. | |||
/// // ... | |||
/// </code> | |||
/// </example> | |||
public class Attention : BaseDenseAttention | |||
{ | |||
public IVariableV1 concat_score_weight; | |||
public IVariableV1 scale; | |||
AttentionArgs args; | |||
string score_mode { get => args.score_mode; } | |||
bool use_scale { get => args.use_scale; } | |||
public Attention(AttentionArgs args) : base(args) | |||
{ | |||
this.args = args; | |||
if (!new List<string> { | |||
"dot", | |||
"concat" | |||
}.Contains(this.score_mode)) | |||
throw new ValueError("Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]"); | |||
} | |||
// Creates variable when `use_scale` is True or `score_mode` is `concat`. | |||
protected override void build(Tensors inputs) { | |||
if (this.use_scale) | |||
this.scale = this.add_weight(name: "scale", | |||
shape: 1, | |||
initializer: tf.ones_initializer, | |||
dtype: this.DType, | |||
trainable: true); | |||
else | |||
this.scale = null; | |||
if (this.score_mode == "concat") | |||
this.concat_score_weight = this.add_weight(name: "concat_score_weight", | |||
shape: 1, | |||
initializer: tf.ones_initializer, | |||
dtype: this.DType, | |||
trainable: true); | |||
else | |||
this.concat_score_weight = null; | |||
base.build(inputs); | |||
} | |||
/// <summary> | |||
/// Calculates attention scores as a query-key dot product. | |||
/// </summary> | |||
/// <param name="query">query: Query tensor of shape `[batch_size, Tq, dim]`.</param> | |||
/// <param name="key">key: Key tensor of shape `[batch_size, Tv, dim]`.</param> | |||
/// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns> | |||
public override Tensor _calculate_scores(Tensor query, Tensor key) | |||
{ | |||
Tensor scores = null; | |||
if (this.score_mode == "dot") | |||
{ | |||
//scores = tf.matmul(query, key, transpose_b: true); | |||
//scores = tf.matmul(tf.squeeze(query),tf.squeeze(key), transpose_b: true); | |||
scores = tf.linalg.einsum("bij,bkj->bik", (query, key)); | |||
if (this.scale != null) | |||
scores *= this.scale.AsTensor(); | |||
} else if (this.score_mode == "concat") { | |||
// Reshape tensors to enable broadcasting. | |||
// Reshape into [batch_size, Tq, 1, dim]. | |||
var q_reshaped = tf.expand_dims(query, axis: -2); | |||
// Reshape into [batch_size, 1, Tv, dim]. | |||
var k_reshaped = tf.expand_dims(key, axis: -3); | |||
if (this.scale != null) | |||
scores = this.concat_score_weight.AsTensor() * | |||
tf.reduce_sum(tf.tanh(this.scale.AsTensor() * (q_reshaped + k_reshaped)), axis: -1); | |||
else | |||
scores = this.concat_score_weight.AsTensor() * | |||
tf.reduce_sum(tf.tanh(q_reshaped + k_reshaped), axis: -1); | |||
} | |||
return scores; | |||
} | |||
public override LayerArgs get_config() => this.args; | |||
//var config = new Dictionary<object, object> { | |||
// { | |||
// "use_scale", | |||
// this.use_scale}, | |||
// { | |||
// "score_mode", | |||
// this.score_mode}}; | |||
//var base_config = base.get_config(); | |||
//return new dict(base_config.items().ToList() + config.items().ToList()); | |||
} | |||
} |
@@ -0,0 +1,257 @@ | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
/// <summary> | |||
/// Base class for attention layers that can be used in sequence DNN/CNN models. | |||
///This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2. | |||
///Attention is formed by three tensors: Query, Key and Value. | |||
/// </summary> | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary> | |||
/// Base Attention class for Dense networks. | |||
/// This class is suitable for Dense or CNN networks, and not for RNN networks. | |||
/// Implementations of attention mechanisms should inherit from this class, and | |||
/// reuse the `apply_attention_scores()` method. | |||
/// </summary> | |||
public class BaseDenseAttention : Layer | |||
{ | |||
BaseDenseAttentionArgs args; | |||
bool causal { get => args.causal; } | |||
float dropout { get => args.dropout; } | |||
protected bool supports_masking; | |||
public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args) | |||
{ | |||
this.args = args; | |||
this.supports_masking = true; | |||
} | |||
/// <summary> | |||
/// Calculates attention scores. | |||
/// </summary> | |||
/// <param name="query">query: Query tensor of shape `[batch_size, Tq, dim]`.</param> | |||
/// <param name="key">key: Key tensor of shape `[batch_size, Tv, dim]`.</param> | |||
/// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns> | |||
public virtual Tensor _calculate_scores(Tensor query, Tensor key) => | |||
throw new NotImplementedException(""); | |||
/// <summary> | |||
/// Applies attention scores to the given value tensor. | |||
/// To use this method in your attention layer, follow the steps: | |||
/// <para> | |||
/// * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape | |||
/// `[batch_size, Tv]` to calculate the attention `scores`. | |||
/// </para> | |||
/// <para> | |||
/// * Pass `scores` and `value` tensors to this method. The method applies | |||
/// `scores_mask`, calculates `attention_distribution = softmax(scores)`, then | |||
/// returns `matmul(attention_distribution, value). | |||
/// </para> | |||
/// <para> | |||
/// * Apply `query_mask` and return the result. | |||
/// </para> | |||
/// </summary> | |||
/// <param name="scores">Scores float tensor of shape `[batch_size, Tq, Tv]`.</param> | |||
/// <param name="value">Value tensor of shape `[batch_size, Tv, dim]`.</param> | |||
/// <param name="scores_mask"> | |||
/// A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or | |||
/// [batch_size, Tq, Tv]`. If given, scores at positions where | |||
/// `scores_mask==False` do not contribute to the result. It must contain | |||
/// at least one `True` value in each line along the last dimension. | |||
/// </param> | |||
/// <param name="training"> | |||
/// Boolean indicating whether the layer should behave in | |||
/// training mode (adding dropout) or in inference mode (no dropout). | |||
/// </param> | |||
/// <returns> | |||
/// <para> | |||
/// Tensor of shape `[batch_size, Tq, dim]`. | |||
/// </para> | |||
/// <para> | |||
/// Attention scores after masking and softmax with shape | |||
/// [batch_size, Tq, Tv]`. | |||
/// </para> | |||
/// </returns> | |||
public (Tensor, Tensor) _apply_scores(Tensor scores, | |||
Tensor value, | |||
Tensor scores_mask = null, | |||
bool? training = null) | |||
{ | |||
if (scores_mask != null) | |||
{ | |||
var padding_mask = tf.logical_not(scores_mask); | |||
// Bias so padding positions do not contribute to attention distribution. | |||
// Note 65504. is the max float16 value. | |||
if (scores.dtype == tf.float16) | |||
scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype); | |||
else | |||
scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype); | |||
} | |||
bool _training; | |||
training ??= false; // TODO: Delete this line when backend.learning_phase is available | |||
if (training == null) | |||
_training = keras.backend.learning_phase() == | |||
Tensorflow.Keras.GraphLearningPhase.train_mode ? | |||
true : false; | |||
else _training = training.Value; | |||
var weights = tf.nn.softmax(scores); | |||
Func<Tensor> dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout); | |||
weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights)); | |||
//return (tf.matmul(weights, value), weights); | |||
return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
{ | |||
Tensors _inp; | |||
Tensors _mask = null; | |||
int count = inputs.Count(); | |||
if (count < 2 || count > 6) throw new ValueError( | |||
$"{ this.name } layer accepts inputs list of length from 2 to 5, " + | |||
$"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." + | |||
$"Received length: {count}."); | |||
bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL; | |||
bool return_attention_scores = false; | |||
if (has_bool) | |||
{ | |||
return_attention_scores = (bool)inputs[count - 1]; | |||
count--; | |||
} | |||
switch (count) | |||
{ | |||
case 2: | |||
_inp = (inputs[0], inputs[1]); | |||
break; | |||
case 3: | |||
_inp = new[] { inputs[0], inputs[1], inputs[2] }; | |||
break; | |||
case 4: | |||
if (inputs[0].shape == inputs[2].shape) | |||
if (inputs[1].shape == inputs[3].shape) | |||
{ | |||
_inp = new[] { inputs[0], inputs[1] }; | |||
_mask = new[] { inputs[2], inputs[3] }; | |||
break; | |||
} | |||
throw new ValueError(); //TODO:Add discriptions for this err | |||
case 5: | |||
_inp = new[] { inputs[0], inputs[1], inputs[2] }; | |||
_mask = (inputs[3], inputs[4]); | |||
break; | |||
default: | |||
throw new ValueError(); //TODO:Add discriptions for this err | |||
} | |||
return call(_inp, _mask, training, return_attention_scores); | |||
} | |||
protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false) | |||
{ | |||
Tensor causal_mask; | |||
//this._validate_call_args(inputs: inputs, mask: mask); | |||
var q = inputs[0]; | |||
var v = inputs[1]; | |||
var k = inputs.Count() > 2 ? inputs[2] : v; | |||
var q_mask = mask != null ? mask[0] : null; | |||
var v_mask = mask != null ? mask[1] : null; | |||
var scores = this._calculate_scores(query: q, key: k); | |||
if (v_mask != null) | |||
// Mask of shape [batch_size, 1, Tv]. | |||
v_mask = tf.expand_dims(v_mask, axis: -2); | |||
if (this.causal) | |||
{ | |||
// Creates a lower triangular mask, so position i cannot attend to | |||
// positions j>i. This prevents the flow of information from the future | |||
// into the past. | |||
var scores_shape = tf.shape(scores); | |||
// causal_mask_shape = [1, Tq, Tv]. | |||
var causal_mask_shape = tf.concat(new List<Tensor> { | |||
tf.ones_like(tf.slice(scores_shape, new[]{0}, new[]{-2})), | |||
tf.concat(new[]{scores_shape[-2], scores_shape[-1]}, 0) | |||
}, axis: 0); | |||
var _causal_mask_shape = new Shape(causal_mask_shape.ToArray<int>()); | |||
causal_mask = _lower_triangular_mask(_causal_mask_shape); | |||
} | |||
else | |||
causal_mask = null; | |||
var scores_mask = _merge_masks(v_mask, causal_mask); | |||
var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training); | |||
if (q_mask != null) | |||
{ | |||
// Mask of shape [batch_size, Tq, 1]. | |||
q_mask = tf.expand_dims(q_mask, axis: -1); | |||
result *= tf.cast(q_mask, dtype: result.dtype); | |||
} | |||
if (return_attention_scores) | |||
return new Tensors(result, attention_scores); | |||
return result; | |||
} | |||
public Tensor compute_mask(Tensors inputs, Tensors mask = null) | |||
{ | |||
this._validate_call_args(inputs: inputs, mask: mask); | |||
if (mask != null) | |||
{ | |||
var q_mask = mask[0]; | |||
if (q_mask == null) | |||
return null; | |||
return tf.convert_to_tensor(q_mask); | |||
} | |||
return null; | |||
} | |||
//public Shape compute_output_shape(Shape input_shape) { | |||
// // return_attention_scores argument of BaseDenseAttention.call method | |||
// // is ignored. Output shape of attention_scores cannot be returned. | |||
// return input_shape[0]; | |||
//} | |||
/// <summary> | |||
/// Validates arguments of the call method. | |||
/// </summary> | |||
public void _validate_call_args(Tensors inputs, Tensors mask) | |||
{ | |||
if (inputs.Count() < 2 || inputs.Count() > 3) | |||
throw new ValueError( | |||
$"{this.name} layer accepts inputs list of length 2 or 3, " + | |||
$"namely [query, value] or [query, value, key]. Received length: {len(inputs)}."); | |||
if (mask != null) | |||
if (mask.Count() < 2 || mask.Count() > inputs.Count()) | |||
throw new ValueError($"{this.name} layer mask must be a list of length 2, " + | |||
$"namely [query_mask, value_mask]. Received length: {len(mask)}."); | |||
} | |||
public static Tensor _lower_triangular_mask(Shape shape) | |||
{ | |||
var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2); | |||
var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1); | |||
return tf.greater_equal(row_index, col_index); | |||
} | |||
public static Tensor _merge_masks(Tensor x, Tensor y) | |||
{ | |||
if (x == null) | |||
return y; | |||
if (y == null) | |||
return x; | |||
return tf.logical_and(x, y); | |||
} | |||
public override LayerArgs get_config() => this.args; | |||
} | |||
} |
@@ -0,0 +1,25 @@ | |||
using System; | |||
using Tensorflow.NumPy; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial class LayersApi | |||
{ | |||
public Attention Attention(bool use_scale = false, | |||
string score_mode = "dot", | |||
bool causal = false, | |||
float dropout = 0f) => | |||
new Attention(new AttentionArgs | |||
{ | |||
use_scale = use_scale, | |||
score_mode = score_mode, | |||
causal = causal, | |||
dropout = dropout | |||
}); | |||
} | |||
} |
@@ -0,0 +1,310 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.NumPy; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Utils; | |||
namespace TensorFlowNET.Keras.UnitTest | |||
{ | |||
[TestClass] | |||
public class AttentionTest : EagerModeTestBase | |||
{ | |||
#region BaseDenseAttention | |||
[TestMethod] | |||
public void test_one_dim_with_mask() | |||
{ | |||
// Scores tensor of shape [1, 1, 1] | |||
var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32); | |||
// Value tensor of shape [1, 1, 1] | |||
var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); | |||
// Scores mask tensor of shape [1, 1, 1] | |||
var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool); | |||
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); | |||
var actual = _tup_1.Item1; | |||
var actual_scores = _tup_1.Item2; | |||
// Expected softmax_scores = [[[1]]] | |||
var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected_scores, actual_scores.numpy()); | |||
// Expected tensor of shape [1, 1, 1]. | |||
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6 | |||
var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_one_dim_no_mask() | |||
{ | |||
// Scores tensor of shape [1, 1, 1] | |||
var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32); | |||
// Value tensor of shape [1, 1, 1] | |||
var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); | |||
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v); | |||
var actual = _tup_1.Item1; | |||
var actual_scores = _tup_1.Item2; | |||
// Expected softmax_scores = [[[1]]] | |||
var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected_scores, actual_scores.numpy()); | |||
// Expected tensor of shape [1, 1, 1]. | |||
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6 | |||
var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_multi_dim_with_mask() | |||
{ | |||
// Scores tensor of shape [1, 1, 3] | |||
var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32); | |||
// Value tensor of shape [1, 3, 1] | |||
var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32); | |||
// Scores mask tensor of shape [1, 1, 3] | |||
var scores_mask = np.array(new[, ,] { { { true, true, false } } }, dtype: np.@bool); | |||
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); | |||
var actual = _tup_1.Item1; | |||
var actual_scores = _tup_1.Item2; | |||
// Expected softmax scores = softmax(scores) with zeros in positions where | |||
// v_mask == False. | |||
// => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863 | |||
// softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137 | |||
// softmax_scores002 = 0 | |||
var expected_scores = np.array(new[, ,] { { { 0.73105857863f, 0.26894142137f, 0f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected_scores, actual_scores.numpy()); | |||
// Expected tensor of shape [1, 1, 1]. | |||
// expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8 | |||
// = 1.35795272077 | |||
//Actually the output is 1.3579528 | |||
var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_multi_dim_no_mask() | |||
{ | |||
// Scores tensor of shape [1, 1, 3] | |||
var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32); | |||
// Value tensor of shape [1, 3, 1] | |||
var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32); | |||
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v); | |||
var actual = _tup_1.Item1; | |||
var actual_scores = _tup_1.Item2; | |||
// Expected softmax_scores = softmax(scores). | |||
// => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1)) | |||
// = 0.42231879825 | |||
// softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1)) | |||
// = 0.15536240349 | |||
// softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1)) | |||
// = 0.42231879825 | |||
//Actually the output is 0.42231882, 0.15536241, 0.42231882 | |||
var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected_scores, actual_scores.numpy()); | |||
// Expected tensor of shape [1, 1, 1]. | |||
// expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7 | |||
// - 0.42231879825 * 0.8 | |||
// = 0.44660872104 | |||
//Actually the output is 0.44660875 | |||
var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_one_dim_batch_size_two() | |||
{ | |||
// Scores tensor of shape [2, 1, 1] | |||
var scores = np.array(new[, ,] { { { 1.1f } }, { { 2.1f } } }, dtype: np.float32); | |||
// Value tensor of shape [2, 1, 1] | |||
var v = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32); | |||
// Scpres mask tensor of shape [2, 1, 1] | |||
var scores_mask = np.array(new[, ,] { { { true } }, { { true } } }, dtype: np.@bool); | |||
var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); | |||
var actual = _tup_1.Item1; | |||
var actual_scores = _tup_1.Item2; | |||
// Expected softmax_scores = [[[1]], [[1]]] | |||
var expected_scores = np.array(new[, ,] { { { 1f } }, { { 1f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected_scores, actual_scores.numpy()); | |||
// Expected tensor of shape [2, 1, 1]. | |||
// expected000 = softmax_scores[0, 0] * 1.6 = 1.6 | |||
// expected100 = softmax_scores[1, 0] * 2.6 = 2.6 | |||
var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_shape_with_dropout() | |||
{ | |||
// scores: Scores float tensor of shape `[batch_size, tq, tv]`. | |||
// value: Value tensor of shape `[batch_size, tv, dim]`. | |||
var batch_size = 4; | |||
var tq = 5; | |||
var tv = 6; | |||
var dim = 7; | |||
var scores = np.ones((batch_size, tq, tv)); | |||
var value = np.ones((batch_size, tv, dim)); | |||
var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f }) | |||
._apply_scores(scores: scores, value: value, training: false); | |||
var actual = _tup_1.Item1; | |||
var actual_scores = _tup_1.Item2; | |||
// Expected Tensor of shape `[batch_size, tq, tv]`. | |||
var expected_scores_shape = new[] { | |||
batch_size, | |||
tq, | |||
tv | |||
}; | |||
Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy()); | |||
// Expected Tensor of shape `[batch_size, tq, dim]`. | |||
var expected_shape = new[] { | |||
batch_size, | |||
tq, | |||
dim | |||
}; | |||
Assert.AreEqual(expected_shape, tf.shape(actual).numpy()); | |||
} | |||
#endregion | |||
// ------------------------------------------------------------------ | |||
#region Attention | |||
[TestMethod] | |||
public void test_example() | |||
{ | |||
//Variable-length int sequences. | |||
var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); | |||
var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); | |||
// Embedding lookup. | |||
var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64); | |||
// Query embeddings of shape [batch_size, Tq, dimension]. | |||
var query_embeddings = token_embedding.Apply(query_input); | |||
// Value embeddings of shape [batch_size, Tv, dimension]. | |||
var value_embeddings = token_embedding.Apply(value_input); | |||
// CNN layer. | |||
var cnn_layer = keras.layers.Conv1D( | |||
filters: 100, | |||
kernel_size: 4, | |||
// Use 'same' padding so outputs have the same shape as inputs. | |||
padding: "same", | |||
activation: "relu"); | |||
var cnn_layer2 = keras.layers.Conv1D( | |||
filters: 100, | |||
kernel_size: 4, | |||
// Use 'same' padding so outputs have the same shape as inputs. | |||
padding: "same", | |||
activation: "relu"); | |||
// Query encoding of shape [batch_size, Tq, filters]. | |||
var query_seq_encoding = cnn_layer.Apply(query_embeddings); | |||
// Value encoding of shape [batch_size, Tv, filters]. | |||
var value_seq_encoding = cnn_layer2.Apply(value_embeddings); | |||
// Query-value attention of shape [batch_size, Tq, filters]. | |||
var query_value_attention_seq = keras.layers.Attention().Apply( | |||
(query_seq_encoding, value_seq_encoding)); | |||
// Reduce over the sequence axis to produce encodings of shape | |||
// [batch_size, filters]. | |||
var query_encoding = keras.layers.GlobalAveragePooling1D().Apply( | |||
query_seq_encoding); | |||
var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply( | |||
query_value_attention_seq); | |||
// Concatenate query and document encodings to produce a DNN input layer. | |||
var input_layer = keras.layers.Concatenate().Apply( | |||
(query_encoding, query_value_attention)); | |||
// Add DNN layers, and create Model. | |||
// ... | |||
} | |||
[TestMethod] | |||
public void test_calculate_scores_one_dim() | |||
{ | |||
// Query tensor of shape [1, 1, 1] | |||
var q = np.array(new[,,] { { { 1.1f } } }, dtype: np.float32); | |||
// Key tensor of shape [1, 1, 1] | |||
var k = np.array(new[,,] { { { 1.6f } } }, dtype: np.float32); | |||
var attention_layer = keras.layers.Attention(); | |||
//attention_layer.build((1)); | |||
var actual = attention_layer._calculate_scores(query: q, key: k); | |||
// Expected tensor of shape [1, 1, 1]. | |||
// expected000 = 1.1*1.6 = 1.76 | |||
// Actually the output is 1.7600001 | |||
var expected = np.array(new[,,] { { { 1.7600001f } } }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_calculate_scores_multi_dim() | |||
{ | |||
// Query tensor of shape [1, 2, 4] | |||
var q = np.array(new[, ,] { { | |||
{ 1f, 1.1f, 1.2f, 1.3f }, | |||
{ 2f, 2.1f, 2.2f, 2.3f } | |||
} }, dtype: np.float32); | |||
// Key tensor of shape [1, 3, 4] | |||
var k = np.array(new[, ,] { { | |||
{ 1.5f, 1.6f, 1.7f, 1.8f }, | |||
{ 2.5f, 2.6f, 2.7f, 2.8f }, | |||
{ 3.5f, 3.6f, 3.7f, 3.8f } | |||
} }, dtype: np.float32); | |||
var attention_layer = keras.layers.Attention(); | |||
//attention_layer.build(((1, 2, 4), (1, 3, 4))); | |||
var actual = attention_layer._calculate_scores(query: q, key: k); | |||
// Expected tensor of shape [1, 2, 3]. | |||
// expected000 = 1.*1.5+1.1*1.6+1.2*1.7+1.3*1.8 = 7.64 | |||
// expected001 = 1.*2.5+1.1*2.6+1.2*2.7+1.3*2.8 = 12.24 | |||
// expected002 = 1.*3.5+1.1*3.6+1.2*3.7+1.3*3.8 = 16.84 | |||
// expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24 | |||
// expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84 | |||
// expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44 | |||
// Actually the output000 is 7.6400003, the output012 is 31.439999 | |||
var expected = np.array(new[, ,] { { | |||
{ 7.6400003f, 12.24f, 16.84f }, | |||
{ 14.24f, 22.84f, 31.439999f } | |||
} }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
[TestMethod] | |||
public void test_calculate_scores_multi_dim_concat() | |||
{ | |||
// Query tensor of shape [1, 2, 4] | |||
var q = np.array(new[, ,] { { | |||
{ 1f, 1.1f, 1.2f, 1.3f }, | |||
{ 2f, 2.1f, 2.2f, 2.3f } | |||
} }, dtype: np.float32); | |||
// Key tensor of shape [1, 3, 4] | |||
var k = np.array(new[, ,] { { | |||
{ 1.5f, 1.6f, 1.7f, 1.8f }, | |||
{ 2.5f, 2.6f, 2.7f, 2.8f }, | |||
{ 3.5f, 3.6f, 3.7f, 3.8f } | |||
} }, dtype: np.float32); | |||
var attention_layer = keras.layers.Attention(score_mode: "concat"); | |||
//attention_layer.concat_score_weight = 1; | |||
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { | |||
Name = "concat_score_weight", | |||
Shape = (1), | |||
DType = TF_DataType.TF_FLOAT, | |||
Getter = base_layer_utils.make_variable, | |||
Overwrite = true, | |||
Initializer = tf.ones_initializer, | |||
Synchronization = VariableSynchronization.Auto, | |||
Aggregation = VariableAggregation.None, | |||
Trainable = true | |||
}); | |||
//attention_layer.build(((1, 2, 4), (1, 3, 4))); | |||
//var actual = keras.backend.get_value(attention_layer._calculate_scores(query: q, key: k)); | |||
var actual = attention_layer._calculate_scores(query: q, key: k); | |||
// pylint:disable=line-too-long | |||
// expected000 = tanh(1.+1.5) + tanh(1.1+1.6) + tanh(1.2+1.7) + tanh(1.3+1.8) = 3.96753427840 | |||
// expected001 = tanh(1.+2.5) + tanh(1.1+2.6) + tanh(1.2+2.7) + tanh(1.3+2.8) = 3.99558784825 | |||
// expected002 = tanh(1.+3.5) + tanh(1.1+3.6) + tanh(1.2+3.7) + tanh(1.3+3.8) = 3.99940254147 | |||
// expected010 = tanh(2.+1.5) + tanh(2.1+1.6) + tanh(2.2+1.7) + tanh(2.3+1.8) = 3.99558784825 | |||
// expected011 = tanh(2.+2.5) + tanh(2.1+2.6) + tanh(2.2+2.7) + tanh(2.3+2.8) = 3.99940254147 | |||
// expected012 = tanh(2.+3.5) + tanh(2.1+3.6) + tanh(2.2+3.7) + tanh(2.3+3.8) = 3.99991913657 | |||
//Actually the output012 is 3.9999194 | |||
var expected = np.array(new[, ,] { { | |||
{ 3.96753427840f, 3.99558784825f, 3.99940254147f }, | |||
{ 3.99558784825f, 3.99940254147f, 3.9999194f } | |||
} }, dtype: np.float32); | |||
Assert.AreEqual(expected, actual.numpy()); | |||
} | |||
#endregion | |||
} | |||
} |