@@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers {
            axis = args.axis;
        }
        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) {
-           Tensor x = inputs;
+           // If a boolean mask is passed as the second input, add a large negative offset to the
+           // masked positions so the softmax drives their probabilities towards zero.
+           Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * -1e9)
+                                         : inputs;
            Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));
            Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true);
            return tf.div(e, s);
@@ -120,7 +120,7 @@ namespace Tensorflow.Keras.Layers
            int count = inputs.Count();
            if (count < 2 || count > 6) throw new ValueError(
-               $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
+               $"{ this.name } layer accepts inputs list of length from 2 to 6, " +
                $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
                $"Received length: {count}.");
@@ -0,0 +1,352 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;

namespace Tensorflow.Keras.Layers
{
    public class MultiHeadAttention : Layer
    {
        static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";

        MultiHeadAttentionArgs args;
        Shape _query_shape = null;
        Shape _key_shape = null;
        Shape _value_shape = null;
        bool _built_from_signature = false;
        EinsumDense _query_dense = null;
        EinsumDense _key_dense = null;
        EinsumDense _value_dense = null;
        EinsumDense _output_dense = null;
        string _dot_product_equation = "";
        string _combine_equation = "";
        Softmax _softmax = null;
        Dropout _dropout_layer = null;
        /// <summary>
        /// Builds einsum equations for the attention computation.
        /// Query, key, value inputs after projection are expected to have the shape
        /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
        /// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
        ///
        /// <para>
        /// The attention operations can be generalized:
        /// </para>
        /// <para>
        /// (1) Query-key dot product:
        /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
        /// [key attention dims], num_heads, channels) -> ([batch dims],
        /// num_heads, [query attention dims], [key attention dims])`
        /// </para><para>
        /// (2) Combination:
        /// `([batch dims], num_heads, [query attention dims], [key attention dims]),
        /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
        /// [query attention dims], num_heads, channels)`
        /// </para>
        /// </summary>
        /// <param name="rank">Rank of query, key, value tensors.</param>
        /// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
        /// that attention will be applied to.</param>
        /// <returns>A tuple of (dot-product equation, combine equation, rank of the attention scores).</returns>
        public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
        {
            var target_notation = _CHR_IDX.Substring(0, rank);
            // `batch_dims` includes the head dim.
            var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
            var letter_offset = rank;
            var source_notation = "";
            for (int i = 0; i < rank; i++)
            {
                if (batch_dims.Contains(i) || i == rank - 1)
                    source_notation += target_notation[i];
                else
                {
                    source_notation += _CHR_IDX[letter_offset];
                    letter_offset += 1;
                }
            }
            var product_notation = new string((from i in batch_dims
                                               select target_notation[i]).Concat(
                                                   from i in attn_axes.as_int_list()
                                                   select target_notation[i]).Concat(
                                                   from i in attn_axes.as_int_list()
                                                   select source_notation[i]).ToArray());
            var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
            var attn_scores_rank = product_notation.Count();
            var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
            return (dot_product_equation, combine_equation, attn_scores_rank);
        }
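        // Worked example (illustrative): for the typical case where the projected query/key/value
        // have rank 4, i.e. (batch, seq, num_heads, channels), and attn_axes = (1), this yields
        //   dot_product_equation = "aecd,abcd->acbe"   // (B,S,N,H),(B,T,N,H) -> (B,N,T,S)
        //   combine_equation     = "acbe,aecd->abcd"   // (B,N,T,S),(B,S,N,H) -> (B,T,N,H)
        //   attn_scores_rank     = 4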
        /// <summary>
        /// Builds an einsum equation for projections inside multi-head attention.
        /// </summary>
        public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
        {
            char _char;
            var input_str = "";
            var kernel_str = "";
            var output_str = "";
            var bias_axes = "";
            var letter_offset = 0;
            foreach (var i in range(free_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                output_str += _char;
            }
            letter_offset += free_dims;
            foreach (var i in range(bound_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                kernel_str += _char;
            }
            letter_offset += bound_dims;
            foreach (var i in range(output_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                kernel_str += _char;
                output_str += _char;
                bias_axes += _char;
            }
            var equation = $"{input_str},{kernel_str}->{output_str}";
            return (equation, bias_axes, output_str.Count());
        }
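        // Worked example (illustrative): _build_proj_equation(free_dims: 2, bound_dims: 1, output_dims: 2)
        // returns ("abc,cde->abde", "de", 4), i.e. an input (batch, seq, features) is projected with a
        // kernel (features, num_heads, head_dim) to (batch, seq, num_heads, head_dim).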
        static Shape _get_output_shape(int output_rank, Shape known_last_dims)
            => (from _ in range(output_rank - known_last_dims.rank)
                select -1).Concat(known_last_dims.as_int_list()).ToArray();
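        // Pads the known trailing dimensions with leading -1 ("infer") entries, e.g.
        // _get_output_shape(3, (num_heads, key_dim)) -> (-1, num_heads, key_dim).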
        public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
        {
            this.args = args;
        }

        public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
            => this._build_from_signature(query.shape, value.shape, key?.shape);

        public void _build_from_signature(Shape query, Shape value, Shape key = null)
        {
            this._built_from_signature = true;
            this._query_shape = query;
            this._value_shape = value;
            if (key == null)
                this._key_shape = this._value_shape;
            else
                this._key_shape = key;
            // Any setup work performed only once should happen in an `init_scope`
            // to avoid creating symbolic Tensors that will later pollute any eager
            // operations.
            tf_with(tf.init_scope(), _ =>
            {
                var free_dims = this._query_shape.rank - 1;
                var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    free_dims, bound_dims: 1, output_dims: 2);
                this._query_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "query");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._key_dense = _get_dense(einsum_equation,
                                             _get_output_shape(output_rank - 1,
                                                 (this.args.NumHeads, this.args.KeyDim)),
                                             this.args.UseBias ? bias_axes : null,
                                             "key");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
                // The per-head value size defaults to KeyDim when ValueDim is not given (as in Keras).
                this._value_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "value");
                // Builds the attention computations for multi-head dot product attention.
                // These computations could be wrapped into the keras attention layer once
                // it supports multi-head einsum computations.
                this._build_attention(output_rank);
                this._output_dense = _build_output_dense(free_dims, "attention_output");
            });
            this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
        }
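        // Illustrative shapes (assuming num_heads = 2, key_dim = 2, as in the unit test below):
        // query (B, 4, 8) and value (B, 2, 8) are projected to
        //   query -> (B, 4, 2, 2), key -> (B, 2, 2, 2), value -> (B, 2, 2, 2)
        // (value_dim defaulting to key_dim) before the attention equations built above are applied.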
        EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
            => new EinsumDense(new EinsumDenseArgs()
            {
                Equation = equation,
                OutputShape = output_shape,
                BiasAxes = bias_axes,
                Name = name,
                KernelInitializer = this.args.KernelInitializer,
                BiasInitializer = this.args.BiasInitializer,
                KernelRegularizer = this.args.KernelRegularizer,
                BiasRegularizer = this.args.BiasRegularizer,
                KernelConstraint = this.args.KernelConstraint,
                BiasConstraint = this.args.BiasConstraint
            });
        EinsumDense _build_output_dense(int free_dims, string name)
        {
            if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]);
            var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
            return _get_dense(einsum_equation,
                              _get_output_shape(output_rank - 1, this.args.OutputShape),
                              this.args.UseBias ? bias_axes : null,
                              name);
        }
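        // Worked example (illustrative): with free_dims = 2 and a rank-1 output shape this calls
        // _build_proj_equation(2, bound_dims: 2, output_dims: 1) = ("abcd,cde->abe", "e", 3),
        // i.e. (batch, seq, num_heads, head_dim) x (num_heads, head_dim, output_dim) -> (batch, seq, output_dim).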
        void _build_attention(int rank)
        {
            if (this.args.AttentionAxis == null)
                this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
            int attn_scores_rank;
            (this._dot_product_equation, this._combine_equation, attn_scores_rank)
                = _build_attention_equation(rank, this.args.AttentionAxis);
            var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
                                  attn_scores_rank).ToArray();
            this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
            this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
        }
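        // For the default rank-4 case the attention axis is (1) (the query sequence axis),
        // so norm_axes ends up as { 3 } and the softmax normalizes over the key sequence axis.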
        Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
        {
            if (attention_mask != null)
            {
                var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
                for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++)
                    attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
            }
            return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
        }
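        // Illustrative: with one attention axis, mask_expansion_axis is -3, so a mask of shape
        // (B, T, S) is expanded to (B, 1, T, S) and broadcasts against scores of shape (B, N, T, S).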
        public Tensors _compute_attention(
            Tensor query,
            Tensor key,
            Tensor value,
            Tensor attention_mask = null,
            bool training = false)
        {
            // Note: Applying scalar multiply at the smaller end of einsum improves
            // XLA performance, but may introduce slight numeric differences in
            // the Transformer attention head.
            query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
            // Take the dot product between "query" and "key" to get the raw
            // attention scores.
            var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
            attention_scores = this._masked_softmax(attention_scores, attention_mask);
            // This is actually dropping out entire tokens to attend to, which might
            // seem a bit unusual, but is taken from the original Transformer paper.
            var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
            // `context_layer` = [B, T, N, H]
            var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
            return (attention_output, attention_scores);
        }
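        // Shape walk-through for the default case: query [B, T, N, H] and key [B, S, N, H]
        // give attention_scores [B, N, T, S]; combining the (masked, dropped-out) scores with
        // value [B, S, N, H] gives attention_output [B, T, N, H].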
        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            Tensors _inp;
            Tensor _mask = null;
            int count = inputs.Count();
            if (count < 2 || count > 5) throw new ValueError(
                $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
                $"namely [query, value, (key), (attention_mask), (return_attention_scores)]." +
                $"Received length: {count}.");
            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)inputs[count - 1];
                count--;
            }
            switch (count)
            {
                case 2:
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
                    // Heuristic: the third input is treated as `key` when its feature dimension
                    // matches the query's; otherwise it is interpreted as an attention mask.
                    if (inputs[2].shape[-1] == inputs[0].shape[-1])
                        _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    else
                    {
                        _inp = (inputs[0], inputs[1]);
                        _mask = inputs[2];
                    }
                    break;
                case 4:
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = inputs[3];
                    break;
                default:
                    throw new ValueError(); // TODO: add a description for this error.
            }
            return call(_inp, _mask, training, return_attention_scores);
        }
        protected Tensors call(Tensors inputs,
                               Tensor attention_mask,
                               bool? training = null,
                               bool return_attention_scores = false)
        {
            var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
            if (!this._built_from_signature)
                this._build_from_signature(query: query, value: value, key: key);
            if (key == null)
                key = value;

            // TODO: Add RaggedTensor support
            //var query_is_ragged = query is tf.RaggedTensor;
            //if (query_is_ragged)
            //{
            //    var query_lengths = query.nested_row_lengths();
            //    query = query.to_tensor();
            //}
            //var key_is_ragged = key is tf.RaggedTensor;
            //var value_is_ragged = value is tf.RaggedTensor;
            //if (key_is_ragged && value_is_ragged)
            //{
            //    // Ensure they have the same shape.
            //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
            //    key = key.to_tensor(shape: bounding_shape);
            //    value = value.to_tensor(shape: bounding_shape);
            //}
            //else if (key_is_ragged)
            //{
            //    key = key.to_tensor(shape: tf.shape(value));
            //}
            //else if (value_is_ragged)
            //{
            //    value = value.to_tensor(shape: tf.shape(key));
            //}

            // N = `num_attention_heads`
            // H = `size_per_head`
            // `query` = [B, T, N, H]
            query = this._query_dense.Apply(query);
            // `key` = [B, S, N, H]
            key = this._key_dense.Apply(key);
            // `value` = [B, S, N, H]
            value = this._value_dense.Apply(value);
            var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false);
            attention_output = this._output_dense.Apply(attention_output);
            //if (query_is_ragged)
            //{
            //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
            //}
            if (return_attention_scores)
                return (attention_output, attention_scores);
            return attention_output;
        }
    }
}
@@ -21,5 +21,36 @@ namespace Tensorflow.Keras.Layers
                causal = causal,
                dropout = dropout
            });
+       public MultiHeadAttention MultiHeadAttention(int num_heads,
+                                                    int key_dim,
+                                                    int? value_dim = null,
+                                                    float dropout = 0f,
+                                                    bool use_bias = true,
+                                                    Shape output_shape = null,
+                                                    Shape attention_axes = null,
+                                                    IInitializer kernel_initializer = null,
+                                                    IInitializer bias_initializer = null,
+                                                    IRegularizer kernel_regularizer = null,
+                                                    IRegularizer bias_regularizer = null,
+                                                    IRegularizer activity_regularizer = null,
+                                                    Action kernel_constraint = null,
+                                                    Action bias_constraint = null) =>
+           new MultiHeadAttention(new MultiHeadAttentionArgs
+           {
+               NumHeads = num_heads,
+               KeyDim = key_dim,
+               ValueDim = value_dim,
+               Dropout = dropout,
+               UseBias = use_bias,
+               OutputShape = output_shape,
+               AttentionAxis = attention_axes,
+               KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
+               BiasInitializer = bias_initializer ?? tf.zeros_initializer,
+               KernelRegularizer = kernel_regularizer,
+               BiasRegularizer = bias_regularizer,
+               ActivityRegularizer = activity_regularizer,
+               KernelConstraint = kernel_constraint,
+               BiasConstraint = bias_constraint,
+           });
    }
}
@@ -15,45 +15,6 @@ namespace TensorFlowNET.Keras.UnitTest
    public class AttentionTest : EagerModeTestBase
    {
        #region BaseDenseAttention
-       [TestMethod]
-       public void test_one_dim_with_mask()
-       {
-           // Scores tensor of shape [1, 1, 1]
-           var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-           // Value tensor of shape [1, 1, 1]
-           var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           // Scores mask tensor of shape [1, 1, 1]
-           var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool);
-           var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected softmax_scores = [[[1]]]
-           var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
-           Assert.AreEqual(expected_scores, actual_scores.numpy());
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-           var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
-       [TestMethod]
-       public void test_one_dim_no_mask()
-       {
-           // Scores tensor of shape [1, 1, 1]
-           var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-           // Value tensor of shape [1, 1, 1]
-           var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected softmax_scores = [[[1]]]
-           var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
-           Assert.AreEqual(expected_scores, actual_scores.numpy());
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-           var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
        [TestMethod]
        public void test_multi_dim_with_mask()
@@ -81,35 +42,6 @@ namespace TensorFlowNET.Keras.UnitTest
            var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32);
            Assert.AreEqual(expected, actual.numpy());
        }
-       [TestMethod]
-       public void test_multi_dim_no_mask()
-       {
-           // Scores tensor of shape [1, 1, 3]
-           var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32);
-           // Value tensor of shape [1, 3, 1]
-           var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32);
-           var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected softmax_scores = softmax(scores).
-           // => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
-           //                      = 0.42231879825
-           //    softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
-           //                      = 0.15536240349
-           //    softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
-           //                      = 0.42231879825
-           // Actually the output is 0.42231882, 0.15536241, 0.42231882
-           var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32);
-           Assert.AreEqual(expected_scores, actual_scores.numpy());
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
-           //             - 0.42231879825 * 0.8
-           //             = 0.44660872104
-           // Actually the output is 0.44660875
-           var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
        [TestMethod]
        public void test_one_dim_batch_size_two()
@@ -132,101 +64,10 @@ namespace TensorFlowNET.Keras.UnitTest
            var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32);
            Assert.AreEqual(expected, actual.numpy());
        }
-       [TestMethod]
-       public void test_shape_with_dropout()
-       {
-           // scores: Scores float tensor of shape `[batch_size, tq, tv]`.
-           // value: Value tensor of shape `[batch_size, tv, dim]`.
-           var batch_size = 4;
-           var tq = 5;
-           var tv = 6;
-           var dim = 7;
-           var scores = np.ones((batch_size, tq, tv));
-           var value = np.ones((batch_size, tv, dim));
-           var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f })
-               ._apply_scores(scores: scores, value: value, training: false);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected Tensor of shape `[batch_size, tq, tv]`.
-           var expected_scores_shape = new[] {
-               batch_size,
-               tq,
-               tv
-           };
-           Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy());
-           // Expected Tensor of shape `[batch_size, tq, dim]`.
-           var expected_shape = new[] {
-               batch_size,
-               tq,
-               dim
-           };
-           Assert.AreEqual(expected_shape, tf.shape(actual).numpy());
-       }
        #endregion
        // ------------------------------------------------------------------
        #region Attention
-       [TestMethod]
-       public void test_example()
-       {
-           // Variable-length int sequences.
-           var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
-           var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
-           // Embedding lookup.
-           var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
-           // Query embeddings of shape [batch_size, Tq, dimension].
-           var query_embeddings = token_embedding.Apply(query_input);
-           // Value embeddings of shape [batch_size, Tv, dimension].
-           var value_embeddings = token_embedding.Apply(value_input);
-           // CNN layer.
-           var cnn_layer = keras.layers.Conv1D(
-               filters: 100,
-               kernel_size: 4,
-               // Use 'same' padding so outputs have the same shape as inputs.
-               padding: "same",
-               activation: "relu");
-           var cnn_layer2 = keras.layers.Conv1D(
-               filters: 100,
-               kernel_size: 4,
-               // Use 'same' padding so outputs have the same shape as inputs.
-               padding: "same",
-               activation: "relu");
-           // Query encoding of shape [batch_size, Tq, filters].
-           var query_seq_encoding = cnn_layer.Apply(query_embeddings);
-           // Value encoding of shape [batch_size, Tv, filters].
-           var value_seq_encoding = cnn_layer2.Apply(value_embeddings);
-           // Query-value attention of shape [batch_size, Tq, filters].
-           var query_value_attention_seq = keras.layers.Attention().Apply(
-               (query_seq_encoding, value_seq_encoding));
-           // Reduce over the sequence axis to produce encodings of shape
-           // [batch_size, filters].
-           var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
-               query_seq_encoding);
-           var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
-               query_value_attention_seq);
-           // Concatenate query and document encodings to produce a DNN input layer.
-           var input_layer = keras.layers.Concatenate().Apply(
-               (query_encoding, query_value_attention));
-           // Add DNN layers, and create Model.
-           // ...
-       }
-       [TestMethod]
-       public void test_calculate_scores_one_dim()
-       {
-           // Query tensor of shape [1, 1, 1]
-           var q = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-           // Key tensor of shape [1, 1, 1]
-           var k = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           var attention_layer = keras.layers.Attention();
-           //attention_layer.build((1));
-           var actual = attention_layer._calculate_scores(query: q, key: k);
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = 1.1*1.6 = 1.76
-           // Actually the output is 1.7600001
-           var expected = np.array(new[, ,] { { { 1.7600001f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
        [TestMethod]
        public void test_calculate_scores_multi_dim()
@@ -305,6 +146,29 @@ namespace TensorFlowNET.Keras.UnitTest
            Assert.AreEqual(expected, actual.numpy());
        }
        #endregion
+       // ------------------------------------------------------------------
+       #region MultiHeadAttention
+       [TestMethod]
+       public void test_masked_attention()
+       {
+           var query = keras.Input(shape: (4, 8));
+           var value = keras.Input(shape: (2, 8));
+           var mask_tensor = keras.Input(shape: (4, 2));
+           var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
+           attention_layer.Apply(new[] { query, value, mask_tensor });
+           var from_data = 10 * np.random.randn(3, 4, 8);
+           var to_data = 10 * np.random.randn(3, 2, 8);
+           var mask_data = np.random.randint(2, size: (3, 4, 2));
+           var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });
+           var null_mask_data = np.ones((3, 4, 2));
+           var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });
+           Assert.AreNotEqual(masked_output_data, unmasked_output_data);
+       }
+       #endregion
    }
}
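// A minimal usage sketch of the layer added in this diff, following the factory method and the
// unit test above (the variable names and shapes here are illustrative only):
var query = keras.Input(shape: (4, 8));
var value = keras.Input(shape: (2, 8));
var mha = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
// With two inputs the layer performs cross-attention of `query` over `value`; the output has
// shape (batch, 4, 8) since OutputShape defaults to the query's last dimension.
var output = mha.Apply(new[] { query, value });
// An optional third tensor is interpreted as `key` or as an attention mask (see `Call` above),
// and a trailing boolean input requests the attention scores as a second output.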