
Attention.cs 7.1 kB

using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Dot-product attention layer, a.k.a. Luong-style attention.
    /// Inputs are a `query` tensor of shape `[batch_size, Tq, dim]`, a `value` tensor of
    /// shape `[batch_size, Tv, dim]` and a `key` tensor of shape
    /// `[batch_size, Tv, dim]`. The calculation follows the steps:
    /// <para>
    /// 1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
    /// product: `scores = tf.matmul(query, key, transpose_b=True)`.
    /// </para>
    /// <para>
    /// 2. Use scores to calculate a distribution with shape
    /// `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
    /// </para>
    /// <para>
    /// 3. Use `distribution` to create a linear combination of `value` with
    /// shape `[batch_size, Tq, dim]`:
    /// `return tf.matmul(distribution, value)`.
    /// </para>
    /// </summary>
    /// <example>
    /// <code>
    /// // Variable-length int sequences.
    /// var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
    /// var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
    /// // Embedding lookup.
    /// var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
    /// // Query embeddings of shape [batch_size, Tq, dimension].
    /// var query_embeddings = token_embedding.Apply(query_input);
    /// // Value embeddings of shape [batch_size, Tv, dimension].
    /// var value_embeddings = token_embedding.Apply(value_input);
    /// // CNN layer.
    /// var cnn_layer = keras.layers.Conv1D(
    ///     filters: 100,
    ///     kernel_size: 4,
    ///     // Use 'same' padding so outputs have the same shape as inputs.
    ///     padding: "same");
    /// var cnn_layer2 = keras.layers.Conv1D(
    ///     filters: 100,
    ///     kernel_size: 4,
    ///     // Use 'same' padding so outputs have the same shape as inputs.
    ///     padding: "same");
    /// // Query encoding of shape [batch_size, Tq, filters].
    /// var query_seq_encoding = cnn_layer.Apply(query_embeddings);
    /// // Value encoding of shape [batch_size, Tv, filters].
    /// var value_seq_encoding = cnn_layer.Apply(value_embeddings);
    /// // Query-value attention of shape [batch_size, Tq, filters].
    /// var query_value_attention_seq = keras.layers.Attention().Apply(
    ///     (query_seq_encoding, value_seq_encoding));
    /// // Reduce over the sequence axis to produce encodings of shape
    /// // [batch_size, filters].
    /// var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
    ///     query_seq_encoding);
    /// var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
    ///     query_value_attention_seq);
    /// // Concatenate query and document encodings to produce a DNN input layer.
    /// var input_layer = keras.layers.Concatenate().Apply(
    ///     (query_encoding, query_value_attention));
    /// // Add DNN layers, and create Model.
    /// // ...
    /// </code>
    /// </example>
    public class Attention : BaseDenseAttention
    {
        public IVariableV1 concat_score_weight;
        public IVariableV1 scale;
        AttentionArgs args;
        string score_mode { get => args.score_mode; }
        bool use_scale { get => args.use_scale; }
        public Attention(AttentionArgs args) : base(args)
        {
            this.args = args;
            if (!new List<string> {
                    "dot",
                    "concat"
                }.Contains(this.score_mode))
                throw new ValueError($"Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]");
        }
        // Creates variables when `use_scale` is true or `score_mode` is "concat".
        public override void build(KerasShapesWrapper input_shape)
        {
            if (this.use_scale)
                this.scale = this.add_weight(name: "scale",
                    shape: 1,
                    initializer: tf.ones_initializer,
                    dtype: this.DType,
                    trainable: true);
            else
                this.scale = null;
            if (this.score_mode == "concat")
                this.concat_score_weight = this.add_weight(name: "concat_score_weight",
                    shape: 1,
                    initializer: tf.ones_initializer,
                    dtype: this.DType,
                    trainable: true);
            else
                this.concat_score_weight = null;
            base.build(input_shape);
        }
        /// <summary>
        /// Calculates attention scores as a query-key dot product.
        /// </summary>
        /// <param name="query">Query tensor of shape `[batch_size, Tq, dim]`.</param>
        /// <param name="key">Key tensor of shape `[batch_size, Tv, dim]`.</param>
        /// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns>
        public override Tensor _calculate_scores(Tensor query, Tensor key)
        {
            Tensor scores = null;
            if (this.score_mode == "dot")
            {
                // Batched dot product over the feature axis, equivalent to
                // tf.matmul(query, key, transpose_b: true).
                scores = tf.linalg.einsum("bij,bkj->bik", (query, key));
                if (this.scale != null)
                    scores *= this.scale.AsTensor();
            }
            else if (this.score_mode == "concat")
            {
                // Reshape tensors to enable broadcasting.
                // Reshape into [batch_size, Tq, 1, dim].
                var q_reshaped = tf.expand_dims(query, axis: -2);
                // Reshape into [batch_size, 1, Tv, dim].
                var k_reshaped = tf.expand_dims(key, axis: -3);
                if (this.scale != null)
                    scores = this.concat_score_weight.AsTensor() *
                        tf.reduce_sum(tf.tanh(this.scale.AsTensor() * (q_reshaped + k_reshaped)), axis: -1);
                else
                    scores = this.concat_score_weight.AsTensor() *
                        tf.reduce_sum(tf.tanh(q_reshaped + k_reshaped), axis: -1);
            }
            return scores;
        }
        public override IKerasConfig get_config() => this.args;
        //var config = new Dictionary<object, object> {
        //    {
        //        "use_scale",
        //        this.use_scale},
        //    {
        //        "score_mode",
        //        this.score_mode}};
        //var base_config = base.get_config();
        //return new dict(base_config.items().ToList() + config.items().ToList());
    }
}
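
A minimal standalone usage sketch follows, assembled from the pieces above: it constructs the layer directly through AttentionArgs (as the constructor in this file does) and applies it to a (query, value) pair, the same tuple form shown in the XML doc example. The settable `use_scale` / `score_mode` members on AttentionArgs and the `tf.ones` placeholder tensors are assumptions made for illustration, not details confirmed by this file.

// Sketch only: assumes AttentionArgs exposes settable use_scale / score_mode members
// (the property names this layer reads) and that Apply accepts a (query, value) tuple
// as in the XML doc example above.
using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;

var args = new AttentionArgs
{
    use_scale = true,    // build() then creates the trainable `scale` weight
    score_mode = "dot"   // "concat" would also create `concat_score_weight`
};
var attention = new Attention(args);

// query: [batch_size, Tq, dim], value: [batch_size, Tv, dim]; only (query, value) is
// passed here, matching the doc example.
var query = tf.ones((2, 8, 16));
var value = tf.ones((2, 10, 16));

// Scores have shape [batch_size, Tq, Tv]; the returned tensor has shape [2, 8, 16].
var output = attention.Apply((query, value));
print(output.shape);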