|
|
@@ -905,13 +905,29 @@ namespace Tensorflow
                 var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
                 var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
                 var ab_matmul = matmul(a_reshape, b_reshape);
-                var dims = new List<int>();
-                dims.AddRange(a_free_dims);
-                dims.AddRange(b_free_dims);
-                if (ab_matmul.shape.Equals(dims))
-                    return ab_matmul;
+                if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list)
+                {
+                    var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray();
+                    if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims))
+                    {
+                        return ab_matmul;
+                    }
+                    else
+                    {
+                        return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name);
+                    }
+                }
                 else
-                    return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
+                {
+                    var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32);
+                    var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32);
+                    var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name);
+                    if(a_free_dims_static is not null && b_free_dims_static is not null)
+                    {
+                        product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray());
+                    }
+                    return product;
+                }
             });
         }
 
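Not part of the diff, just a standalone illustration of the shape rule the hunk above relies on: tensordot keeps the non-contracted ("free") dimensions of `a` followed by those of `b`, and the new branch only skips the final reshape when the matmul result already has exactly that shape. The class and method names below (TensordotShapeSketch, FreeDims) are made up for this sketch and are not TensorFlow.NET API.

// Minimal plain-C# sketch (no TensorFlow.NET types): compute the expected
// tensordot output shape as concat(free dims of a, free dims of b).
using System;
using System.Linq;

class TensordotShapeSketch
{
    // Dimensions of `shape` that are not contracted away.
    static int[] FreeDims(int[] shape, int[] axes) =>
        Enumerable.Range(0, shape.Length)
                  .Where(i => !axes.Contains(i))
                  .Select(i => shape[i])
                  .ToArray();

    static void Main()
    {
        int[] shapeA = { 2, 3, 4 };                  // contract axis 2 of a
        int[] shapeB = { 4, 5 };                     // against axis 0 of b
        var aFree = FreeDims(shapeA, new[] { 2 });   // [2, 3]
        var bFree = FreeDims(shapeB, new[] { 0 });   // [5]
        var result = aFree.Concat(bFree).ToArray();  // expected output shape [2, 3, 5]
        Console.WriteLine(string.Join(", ", result));
    }
}
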
@@ -927,14 +943,42 @@ namespace Tensorflow
                 return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
                     Binding.range(0, axe).ToArray());
             }
-            else
+            else if(axes.rank == 1)
             {
+                if (axes.shape[0] != 2)
+                {
+                    throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
+                }
                 (int a_axe, int b_axe) = (axes[0], axes[1]);
                 return (new[] { a_axe }, new[] { b_axe });
             }
+            else if(axes.rank == 2)
+            {
+                if (axes.shape[0] != 2)
+                {
+                    throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
+                }
+                int[] a_axes = new int[axes.shape[1]];
+                int[] b_axes = new int[axes.shape[1]];
+                for(int i = 0; i < a_axes.Length; i++)
+                {
+                    a_axes[i] = axes[0, i];
+                    b_axes[i] = axes[1, i];
+                    if (a_axes[i] == -1 || b_axes[i] == -1)
+                    {
+                        throw new ValueError($"Different number of contraction axes `a` and `b`," +
+                            $"{len(a_axes)} != {len(b_axes)}.");
+                    }
+                }
+                return (a_axes, b_axes);
+            }
+            else
+            {
+                throw new ValueError($"Invalid rank {axes.rank} to make tensor dot.");
+            }
         }
 
-        static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
+        static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
         {
             if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
             {
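Also outside the diff: a hedged sketch, written against plain C# arrays rather than NDArray, of the three axes layouts the extended _tensordot_axes now accepts. AxesSketch and the From* method names are illustrative only.

// axes as rank 0 (an int N): contract the last N axes of `a` with the first N of `b`.
// axes as rank 1 ([a_axe, b_axe]): a single contraction pair.
// axes as rank 2 ([[a_axes...], [b_axes...]]): one contraction pair per column.
using System;
using System.Linq;

static class AxesSketch
{
    static (int[], int[]) FromScalar(int ndimA, int n) =>
        (Enumerable.Range(ndimA - n, n).ToArray(), Enumerable.Range(0, n).ToArray());

    static (int[], int[]) FromPair(int[] axes) =>
        axes.Length == 2
            ? (new[] { axes[0] }, new[] { axes[1] })
            : throw new ArgumentException("`axes` must be an integer or have length 2.");

    static (int[], int[]) FromMatrix(int[][] axes) =>
        axes.Length == 2 && axes[0].Length == axes[1].Length
            ? (axes[0], axes[1])
            : throw new ArgumentException("Different number of contraction axes for `a` and `b`.");

    static void Main()
    {
        Console.WriteLine(string.Join(",", FromScalar(3, 2).Item1));                                     // 1,2
        Console.WriteLine(string.Join(",", FromPair(new[] { 2, 0 }).Item1));                             // 2
        Console.WriteLine(string.Join(",", FromMatrix(new[] { new[] { 1, 2 }, new[] { 0, 1 } }).Item2)); // 0,1
    }
}
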
@@ -977,6 +1021,58 @@ namespace Tensorflow
                 var reshaped_a = array_ops.reshape(a_trans, new_shape);
                 return (reshaped_a, free_dims, free_dims);
             }
+            else
+            {
+                int[] free_dims_static;
+                Tensor converted_shape_a, converted_axes, converted_free;
+                if (a.shape.ndim != -1)
+                {
+                    var shape_a = a.shape.as_int_list();
+                    for(int i = 0; i < axes.Length; i++)
+                    {
+                        if (axes[i] < 0)
+                        {
+                            axes[i] += shape_a.Length;
+                        }
+                    }
+                    var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();
+
+                    var axes_dims = axes.Select(i => shape_a[i]);
+                    var free_dims = free.Select(i => shape_a[i]).ToArray();
+                    free_dims_static = free_dims;
+                    converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
+                    converted_free = ops.convert_to_tensor(free, dtypes.int32, "free");
+                    converted_shape_a = array_ops.shape(a);
+                }
+                else
+                {
+                    free_dims_static = null;
+                    converted_shape_a = array_ops.shape(a);
+                    var rank_a = array_ops.rank(a);
+                    converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
+                    converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a);
+                    (converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)),
+                        converted_axes, dtypes.int32);
+                }
+                var converted_free_dims = array_ops.gather(converted_shape_a, converted_free);
+                var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes);
+                var prod_free_dims = reduce_prod(converted_free_dims);
+                var prod_axes_dims = reduce_prod(converted_axes_dims);
+                Tensor reshaped_a;
+                if (flipped)
+                {
+                    var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0);
+                    var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims });
+                    reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
+                }
+                else
+                {
+                    var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0);
+                    var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims });
+                    reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
+                }
+                return (reshaped_a, converted_free_dims, free_dims_static);
+            }
 
             throw new NotImplementedException("_tensordot_reshape");
         }
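
A last standalone sketch, not taken from the diff: for a statically known shape this is roughly the plan the new else branch of _tensordot_reshape builds with tensor ops at graph time: normalize negative axes, split dimensions into free and contracted ones, and flatten the operand to a 2-D matrix of [prod(free), prod(contracted)], swapped when flipped is true. ReshapeSketch and Plan are hypothetical names used only here.

using System;
using System.Linq;

class ReshapeSketch
{
    // Returns the transpose permutation and the 2-D shape that turn a tensordot
    // operand into a matrix ready for matmul, mirroring the logic in the branch above.
    static (int[] perm, int[] newShape) Plan(int[] shape, int[] axes, bool flipped)
    {
        // Normalize negative axis indices, as the static branch of the diff does.
        axes = axes.Select(ax => ax < 0 ? ax + shape.Length : ax).ToArray();
        var free = Enumerable.Range(0, shape.Length).Where(i => !axes.Contains(i)).ToArray();
        int prodFree = free.Aggregate(1, (p, i) => p * shape[i]);
        int prodAxes = axes.Aggregate(1, (p, i) => p * shape[i]);
        return flipped
            ? (axes.Concat(free).ToArray(), new[] { prodAxes, prodFree })
            : (free.Concat(axes).ToArray(), new[] { prodFree, prodAxes });
    }

    static void Main()
    {
        var (perm, newShape) = Plan(new[] { 2, 3, 4 }, new[] { -1 }, flipped: false);
        Console.WriteLine(string.Join(",", perm));      // 0,1,2
        Console.WriteLine(string.Join(",", newShape));  // 6,4
    }
}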
|
|
|