From 4aab86b8768a1653d2d427445e2c8395fc60c06a Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 27 Apr 2023 02:11:33 +0800 Subject: [PATCH] Fix the error when using layers.Input with unknown batch size. --- src/TensorFlowNET.Core/Operations/math_ops.cs | 112 ++++++++++++++++-- .../Model/ModelBuildTest.cs | 42 +++++++ 2 files changed, 146 insertions(+), 8 deletions(-) create mode 100644 test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index a89e7a22..f7b428bb 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -905,13 +905,29 @@ namespace Tensorflow var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes); var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true); var ab_matmul = matmul(a_reshape, b_reshape); - var dims = new List(); - dims.AddRange(a_free_dims); - dims.AddRange(b_free_dims); - if (ab_matmul.shape.Equals(dims)) - return ab_matmul; + if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list) + { + var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray(); + if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims)) + { + return ab_matmul; + } + else + { + return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name); + } + } else - return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name); + { + var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32); + var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32); + var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name); + if(a_free_dims_static is not null && b_free_dims_static is not null) + { + product.shape = new 
Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray()); + } + return product; + } }); } @@ -927,14 +943,42 @@ namespace Tensorflow return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(), Binding.range(0, axe).ToArray()); } - else + else if(axes.rank == 1) { + if (axes.shape[0] != 2) + { + throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}."); + } (int a_axe, int b_axe) = (axes[0], axes[1]); return (new[] { a_axe }, new[] { b_axe }); } + else if(axes.rank == 2) + { + if (axes.shape[0] != 2) + { + throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}."); + } + int[] a_axes = new int[axes.shape[1]]; + int[] b_axes = new int[axes.shape[1]]; + for(int i = 0; i < a_axes.Length; i++) + { + a_axes[i] = axes[0, i]; + b_axes[i] = axes[1, i]; + if (a_axes[i] == -1 || b_axes[i] == -1) + { + throw new ValueError($"Different number of contraction axes `a` and `b`," + + $"{len(a_axes)} != {len(b_axes)}."); + } + } + return (a_axes, b_axes); + } + else + { + throw new ValueError($"Invalid rank {axes.rank} to make tensor dot."); + } } - static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) + static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) { if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple)))) { @@ -977,6 +1021,58 @@ namespace Tensorflow var reshaped_a = array_ops.reshape(a_trans, new_shape); return (reshaped_a, free_dims, free_dims); } + else + { + int[] free_dims_static; + Tensor converted_shape_a, converted_axes, converted_free; + if (a.shape.ndim != -1) + { + var shape_a = a.shape.as_int_list(); + for(int i = 0; i < axes.Length; i++) + { + if (axes[i] < 0) + { + axes[i] += shape_a.Length; + } + } + var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray(); + + var axes_dims = axes.Select(i => shape_a[i]); + var free_dims = free.Select(i => shape_a[i]).ToArray(); 
+ free_dims_static = free_dims; + converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes"); + converted_free = ops.convert_to_tensor(free, dtypes.int32, "free"); + converted_shape_a = array_ops.shape(a); + } + else + { + free_dims_static = null; + converted_shape_a = array_ops.shape(a); + var rank_a = array_ops.rank(a); + converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes"); + converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a); + (converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)), + converted_axes, dtypes.int32); + } + var converted_free_dims = array_ops.gather(converted_shape_a, converted_free); + var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes); + var prod_free_dims = reduce_prod(converted_free_dims); + var prod_axes_dims = reduce_prod(converted_axes_dims); + Tensor reshaped_a; + if (flipped) + { + var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0); + var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims }); + reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape); + } + else + { + var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0); + var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims }); + reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape); + } + return (reshaped_a, converted_free_dims, free_dims_static); + } throw new NotImplementedException("_tensordot_reshape"); } diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs new file mode 100644 index 00000000..3b158279 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs @@ -0,0 +1,42 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; 
+using System.Threading.Tasks;
+using static Tensorflow.Binding;
+
+namespace TensorflowNET.Keras
+{
+    [TestClass]
+    public class ModelBuildTest
+    {
+        [TestMethod]
+        public void DenseBuild()
+        {
+            // Two-dimensional input with unknown batch size; a fresh Dense layer is built for it.
+            var input = tf.keras.layers.Input((17, 60));
+            var dense = tf.keras.layers.Dense(64);
+            var output = dense.Apply(input);
+            var model = tf.keras.Model(input, output);
+
+            // One-dimensional input with unknown batch size; uses its own layer so its build path runs.
+            var input_2 = tf.keras.layers.Input((60));
+            var dense_2 = tf.keras.layers.Dense(64);
+            var output_2 = dense_2.Apply(input_2);
+            var model_2 = tf.keras.Model(input_2, output_2);
+
+            // Two-dimensional input with specified batch size; uses its own layer so its build path runs.
+            var input_3 = tf.keras.layers.Input((17, 60), 8);
+            var dense_3 = tf.keras.layers.Dense(64);
+            var output_3 = dense_3.Apply(input_3);
+            var model_3 = tf.keras.Model(input_3, output_3);
+
+            // One-dimensional input with specified batch size; uses its own layer so its build path runs.
+            var input_4 = tf.keras.layers.Input((60), 8);
+            var dense_4 = tf.keras.layers.Dense(64);
+            var output_4 = dense_4.Apply(input_4);
+            var model_4 = tf.keras.Model(input_4, output_4);
+        }
+    }
+}