diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
index 1a936aa7..1b82e0a9 100644
--- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
+++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
@@ -55,6 +55,9 @@ namespace Tensorflow.Keras.Layers
         {
             var target_notation = _CHR_IDX.Substring(0, rank);
             // `batch_dims` includes the head dim.
+            // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+            // Since range(rank) is an IEnumerable like (0, 1, 2 ...) whose index is equal to its value,
+            // use IEnumerable.Except instead of np.delete, which is unavailable.
             var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
             var letter_offset = rank;
             var source_notation = "";
@@ -68,14 +71,14 @@ namespace Tensorflow.Keras.Layers
                     letter_offset += 1;
                 }
             }
-            var product_notation = "".Insert(0, new string((from i in batch_dims
-                                                            select (char)(int)target_notation[i]).Concat(
-
-                                                            from i in attn_axes.as_int_list()
-                                                            select (char)(int)target_notation[i]).Concat(
-
-                                                            from i in attn_axes.as_int_list()
-                                                            select source_notation[i]).ToArray()));
+            var product_notation = new string((from i in batch_dims
+                                               select target_notation[i]).Concat(
+
+                                               from i in attn_axes.as_int_list()
+                                               select target_notation[i]).Concat(
+
+                                               from i in attn_axes.as_int_list()
+                                               select source_notation[i]).ToArray());
             var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
             var attn_scores_rank = product_notation.Count();
             var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
@@ -163,7 +166,7 @@ namespace Tensorflow.Keras.Layers
                                                            this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
             this._value_dense = _get_dense(einsum_equation,
                                            _get_output_shape(output_rank - 1,
-                                                             (this.args.NumHeads, this.args.ValueDim ?? -1)),
+                                                             (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
                                            this.args.UseBias ? bias_axes : null,
                                            "value");
             // Builds the attention computations for multi-head dot product attention.
@@ -235,7 +238,7 @@ namespace Tensorflow.Keras.Layers
             // Note: Applying scalar multiply at the smaller end of einsum improves
             // XLA performance, but may introduce slight numeric differences in
             // the Transformer attention head.
-            query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
+            query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim)));
             // Take the dot product between "query" and "key" to get the raw
             // attention scores.
             var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
@@ -273,7 +276,7 @@ namespace Tensorflow.Keras.Layers
                     _inp = (inputs[0], inputs[1]);
                     break;
                 case 3:
-                    if (inputs[2].shape[-1] != inputs[0].shape[-1])
+                    if (inputs[2].shape[-1] == inputs[1].shape[-1])
                         _inp = new[] { inputs[0], inputs[1], inputs[2] };
                     else
                     {
diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
index 7f85cb5e..2bd987a7 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
@@ -228,7 +228,7 @@ namespace Tensorflow.Keras.Layers
                                                                Shape output_shape,
                                                                bool left_elided = false)
         {
-            List bias_shape;
+            List bias_shape;
             Dictionary output_dim_map;
             Dictionary input_dim_map;

@@ -275,8 +275,8 @@ namespace Tensorflow.Keras.Layers
                     var input_shape_at_dim = input_shape[input_dim_map[dim]];
                     if (output_dim_map.TryGetValue(dim, out int index))
                     {
-                        var output_shape_at_dim = output_shape[index];
-                        if (output_shape_at_dim != input_shape_at_dim)
+                        var output_shape_at_dim = _output_shape[index];
+                        if (output_shape_at_dim != -1 && output_shape_at_dim != input_shape_at_dim)
                             throw new ValueError($"Input shape and output shape do not match at shared dimension '{dim}'. " +
                                 $"Input shape is {input_shape_at_dim}, " +
                                 $"and output shape is {output_shape[output_dim_map[dim]]}.");
@@ -299,7 +299,7 @@ namespace Tensorflow.Keras.Layers
                 if (input_dim_map.ContainsKey(dim))
                     weight_shape.append(input_shape[input_dim_map[dim]]);
                 else if (output_dim_map.ContainsKey(dim))
-                    weight_shape.append(output_shape[output_dim_map[dim]]);
+                    weight_shape.append(_output_shape[output_dim_map[dim]]);
                 else
                     throw new ValueError($"Weight dimension '{dim}' did not have a match in " +
                         $"either the input spec '{input_spec}' " + $"or the output spec '{output_spec}'. " +
@@ -310,7 +310,7 @@ namespace Tensorflow.Keras.Layers
             {
                 var num_left_elided = left_elided ? elided : 0;
                 var idx_map = output_spec.Select((_char, i) => (i, _char))
-                                         .ToDictionary(_ => _._char, _ => output_shape[_.i + num_left_elided]);
+                                         .ToDictionary(_ => _._char, _ => _output_shape[_.i + num_left_elided]);
                 foreach (var _char in bias_axes)
                     if (!output_spec.Contains(_char))
                         throw new ValueError($"Bias dimension '{_char}' was requested," +
@@ -327,7 +327,7 @@ namespace Tensorflow.Keras.Layers
             else
                 bias_shape = null;
             return (weight_shape.ToArray(),
-                    (bias_shape ?? new List()).ToArray(),
+                    (bias_shape ?? new List()).ToArray(),
                     _output_shape.ToArray());
         }
     }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
index 54ac3795..28c5c5bc 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
@@ -151,19 +151,21 @@ namespace TensorFlowNET.Keras.UnitTest
         [TestMethod]
         public void test_masked_attention()
         {
+            var batch_size = 3;
+
            var query = keras.Input(shape: (4, 8));
            var value = keras.Input(shape: (2, 8));
            var mask_tensor = keras.Input(shape:(4, 2));
            var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
            attention_layer.Apply(new[] { query, value, mask_tensor });

-           var from_data = 10 * np.random.randn(3, 4, 8);
-           var to_data = 10 * np.random.randn(3, 2, 8);
+           var from_data = 10 * np.random.randn(batch_size, 4, 8);
+           var to_data = 10 * np.random.randn(batch_size, 2, 8);

-           var mask_data = np.random.randint(2, size: (3, 4, 2));
+           var mask_data = np.random.randint(2, size: (batch_size, 4, 2));
            var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });

-           var null_mask_data = np.ones((3, 4, 2));
+           var null_mask_data = np.ones((batch_size, 4, 2));
            var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });

            Assert.AreNotEqual(masked_output_data, unmasked_output_data);
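
Note on the `_build_attention_equation` change: the rewritten LINQ drops the redundant `"".Insert(0, ...)` wrapper and the `(char)(int)` casts but builds the same einsum strings. The sketch below is a standalone illustration, not library code; it assumes `_CHR_IDX` is the lowercase alphabet, a rank-4 projected query of shape (batch, target_len, num_heads, key_dim), and the default attention axis (1,). The class and variable names are invented for the example.

    using System;
    using System.Linq;

    class EquationSketch
    {
        static void Main()
        {
            const string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";
            int rank = 4;                 // projected query: (batch, target_len, num_heads, key_dim)
            int[] attn_axes = { 1 };      // default attention axis for 3-D inputs

            var target_notation = _CHR_IDX.Substring(0, rank);
            // batch_dims = every dim except the attention axes and the head dim (rank - 1),
            // i.e. np.delete(range(rank), attn_axes + (rank - 1,)) in the Python source.
            var batch_dims = Enumerable.Range(0, rank)
                                       .Except(attn_axes.Concat(new[] { rank - 1 }))
                                       .ToArray();

            var letter_offset = rank;
            var source_notation = "";
            for (int i = 0; i < rank; i++)
            {
                if (batch_dims.Contains(i) || i == rank - 1)
                    source_notation += target_notation[i];        // shared batch / head letters
                else
                    source_notation += _CHR_IDX[letter_offset++]; // fresh letter per attention axis
            }

            // Same structure as the patched LINQ: batch letters, then target attention letters,
            // then source attention letters.
            var product_notation = new string((from i in batch_dims
                                               select target_notation[i]).Concat(
                                               from i in attn_axes
                                               select target_notation[i]).Concat(
                                               from i in attn_axes
                                               select source_notation[i]).ToArray());

            Console.WriteLine($"{source_notation},{target_notation}->{product_notation}");   // aecd,abcd->acbe
            Console.WriteLine($"{product_notation},{source_notation}->{target_notation}");   // acbe,aecd->abcd
        }
    }

Running it prints the dot-product and combine equations for the common rank-4 case, which is the form MultiHeadAttention passes to tf.linalg.einsum for 3-D query/value inputs.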
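
Note on `test_masked_attention`: the test only asserts that the masked and unmasked outputs differ. The toy sketch below (standalone, values and names invented) shows the usual mechanism behind that difference, assuming the 0/1 mask is folded into the attention scores as a large negative additive bias before softmax, which drives the weights of masked positions to (near) zero.

    using System;
    using System.Linq;

    class MaskedSoftmaxSketch
    {
        static double[] Softmax(double[] x)
        {
            var max = x.Max();
            var exps = x.Select(v => Math.Exp(v - max)).ToArray();
            var sum = exps.Sum();
            return exps.Select(v => v / sum).ToArray();
        }

        static void Main()
        {
            double[] scores = { 2.0, 1.0, 0.5 };   // raw attention scores for one query (made up)
            int[] mask = { 1, 0, 1 };              // 0 = position must not be attended to

            var unmasked = Softmax(scores);
            // Add -1e9 to the scores of masked positions, then softmax.
            var masked = Softmax(scores.Select((s, i) => s + (1 - mask[i]) * -1e9).ToArray());

            Console.WriteLine(string.Join(", ", unmasked.Select(v => v.ToString("F3"))));
            Console.WriteLine(string.Join(", ", masked.Select(v => v.ToString("F3"))));
            // The masked distribution puts ~0 weight on index 1, so the attention output
            // (weights · values) changes, which is what the unit test asserts.
        }
    }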