Browse Source

Fix the error when using layers.Input with unknown batch size.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
4aab86b876
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
2 changed files with 146 additions and 8 deletions
  1. +104
    -8
      src/TensorFlowNET.Core/Operations/math_ops.cs
  2. +42
    -0
      test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs

+ 104
- 8
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -905,13 +905,29 @@ namespace Tensorflow
var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
var ab_matmul = matmul(a_reshape, b_reshape);
var dims = new List<int>();
dims.AddRange(a_free_dims);
dims.AddRange(b_free_dims);
if (ab_matmul.shape.Equals(dims))
return ab_matmul;
if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list)
{
var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray();
if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims))
{
return ab_matmul;
}
else
{
return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name);
}
}
else
return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
{
var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32);
var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32);
var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name);
if(a_free_dims_static is not null && b_free_dims_static is not null)
{
product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray());
}
return product;
}
});
}

@@ -927,14 +943,42 @@ namespace Tensorflow
return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
Binding.range(0, axe).ToArray());
}
else
else if(axes.rank == 1)
{
if (axes.shape[0] != 2)
{
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
}
(int a_axe, int b_axe) = (axes[0], axes[1]);
return (new[] { a_axe }, new[] { b_axe });
}
else if(axes.rank == 2)
{
if (axes.shape[0] != 2)
{
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
}
int[] a_axes = new int[axes.shape[1]];
int[] b_axes = new int[axes.shape[1]];
for(int i = 0; i < a_axes.Length; i++)
{
a_axes[i] = axes[0, i];
b_axes[i] = axes[1, i];
if (a_axes[i] == -1 || b_axes[i] == -1)
{
throw new ValueError($"Different number of contraction axes `a` and `b`," +
$"{len(a_axes)} != {len(b_axes)}.");
}
}
return (a_axes, b_axes);
}
else
{
throw new ValueError($"Invalid rank {axes.rank} to make tensor dot.");
}
}

static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
{
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
{
@@ -977,6 +1021,58 @@ namespace Tensorflow
var reshaped_a = array_ops.reshape(a_trans, new_shape);
return (reshaped_a, free_dims, free_dims);
}
else
{
int[] free_dims_static;
Tensor converted_shape_a, converted_axes, converted_free;
if (a.shape.ndim != -1)
{
var shape_a = a.shape.as_int_list();
for(int i = 0; i < axes.Length; i++)
{
if (axes[i] < 0)
{
axes[i] += shape_a.Length;
}
}
var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();

var axes_dims = axes.Select(i => shape_a[i]);
var free_dims = free.Select(i => shape_a[i]).ToArray();
free_dims_static = free_dims;
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
converted_free = ops.convert_to_tensor(free, dtypes.int32, "free");
converted_shape_a = array_ops.shape(a);
}
else
{
free_dims_static = null;
converted_shape_a = array_ops.shape(a);
var rank_a = array_ops.rank(a);
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a);
(converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)),
converted_axes, dtypes.int32);
}
var converted_free_dims = array_ops.gather(converted_shape_a, converted_free);
var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes);
var prod_free_dims = reduce_prod(converted_free_dims);
var prod_axes_dims = reduce_prod(converted_axes_dims);
Tensor reshaped_a;
if (flipped)
{
var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0);
var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims });
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
}
else
{
var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0);
var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims });
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
}
return (reshaped_a, converted_free_dims, free_dims_static);
}

throw new NotImplementedException("_tensordot_reshape");
}


+ 42
- 0
test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs View File

@@ -0,0 +1,42 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;

namespace TensorflowNET.Keras
{
    /// <summary>
    /// Smoke tests that build Keras functional models from symbolic inputs.
    /// Each case passes when graph construction completes without throwing.
    /// </summary>
    [TestClass]
    public class ModelBuildTest
    {
        /// <summary>
        /// Applies a freshly created Dense layer to inputs declared with and
        /// without an explicit batch size, then wraps the result in a model.
        /// Every case uses its own Dense instance so that the layer is built
        /// against that case's input rather than reusing one already built
        /// by an earlier case.
        /// </summary>
        [TestMethod]
        public void DenseBuild()
        {
            // Two-dimensional input with unknown batch size.
            var input = tf.keras.layers.Input((17, 60));
            var dense = tf.keras.layers.Dense(64);
            var output = dense.Apply(input);
            var model = tf.keras.Model(input, output);

            // One-dimensional input with unknown batch size.
            var input_2 = tf.keras.layers.Input((60));
            var dense_2 = tf.keras.layers.Dense(64);
            var output_2 = dense_2.Apply(input_2);
            var model_2 = tf.keras.Model(input_2, output_2);

            // Two-dimensional input with a specified batch size.
            var input_3 = tf.keras.layers.Input((17, 60), 8);
            var dense_3 = tf.keras.layers.Dense(64);
            var output_3 = dense_3.Apply(input_3);
            var model_3 = tf.keras.Model(input_3, output_3);

            // One-dimensional input with a specified batch size.
            var input_4 = tf.keras.layers.Input((60), 8);
            var dense_4 = tf.keras.layers.Dense(64);
            var output_4 = dense_4.Apply(input_4);
            var model_4 = tf.keras.Model(input_4, output_4);
        }
    }
}

Loading…
Cancel
Save