|
|
@@ -868,133 +868,92 @@ namespace Tensorflow |
|
|
|
public static Tensor tanh(Tensor x, string name = null) |
|
|
|
=> gen_math_ops.tanh(x, name); |
|
|
|
|
|
|
|
public static Tensor tensordot(Tensor x, Tensor y, int[] axes, string name = null) |
|
|
|
public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null) |
|
|
|
{ |
|
|
|
Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) |
|
|
|
return tf_with(ops.name_scope(name, "Tensordot", new { a, b, axes }), scope => |
|
|
|
{ |
|
|
|
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple)))) |
|
|
|
{ |
|
|
|
var shape_a = a.shape.dims; |
|
|
|
|
|
|
|
// axes |
|
|
|
int iter = 0; |
|
|
|
foreach (int i in axes) |
|
|
|
{ |
|
|
|
if (i >= 0) |
|
|
|
axes[0 + iter] = i; |
|
|
|
else |
|
|
|
axes[0 + iter] = i + len(shape_a); |
|
|
|
iter++; |
|
|
|
} |
|
|
|
|
|
|
|
// free |
|
|
|
int[] free = { }; |
|
|
|
iter = 0; |
|
|
|
foreach (int i in Enumerable.Range(0, len(axes))) |
|
|
|
if (!Array.Exists(axes, i => i == i)) |
|
|
|
free[free.Length] = i; |
|
|
|
|
|
|
|
// free_dims |
|
|
|
int[] free_dims = { }; |
|
|
|
foreach (int i in free) |
|
|
|
free_dims[free_dims.Length] = (int)shape_a[i]; |
|
|
|
|
|
|
|
int prod_free = (int)np.prod(free_dims); |
|
|
|
|
|
|
|
// prod_axes |
|
|
|
int[] prod_axes_pre = { }; |
|
|
|
foreach (int i in axes) |
|
|
|
prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i]; |
|
|
|
int prod_axes = (int)np.prod(prod_axes_pre); |
|
|
|
|
|
|
|
// perm |
|
|
|
Tensor perm; |
|
|
|
if (flipped) |
|
|
|
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free); |
|
|
|
else |
|
|
|
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free) |
|
|
|
+ ops.convert_to_tensor(list(axes)); |
|
|
|
|
|
|
|
// new_shape |
|
|
|
Shape new_shape; |
|
|
|
if (flipped) |
|
|
|
new_shape = new Shape(new int[] { prod_axes, prod_free }); |
|
|
|
else |
|
|
|
new_shape = new Shape(new int[] { prod_free, prod_axes }); |
|
|
|
} |
|
|
|
name = scope; |
|
|
|
var (a_axes, b_axes) = _tensordot_axes(a, axes); |
|
|
|
var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes); |
|
|
|
var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true); |
|
|
|
var ab_matmul = matmul(a_reshape, b_reshape); |
|
|
|
var dims = new List<int>(); |
|
|
|
dims.AddRange(a_free_dims); |
|
|
|
dims.AddRange(b_free_dims); |
|
|
|
if (ab_matmul.shape.Equals(dims)) |
|
|
|
return ab_matmul; |
|
|
|
else |
|
|
|
return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException("_tensordot_reshape"); |
|
|
|
static (int[], int[]) _tensordot_axes(Tensor a, NDArray axes) |
|
|
|
{ |
|
|
|
if (axes.rank == 0) |
|
|
|
{ |
|
|
|
int axe = axes; |
|
|
|
if (axe > a.shape.ndim) |
|
|
|
throw new ValueError("`axes` must not be larger than the number of " + |
|
|
|
$"dimensions of tensor {a}. Received {axes}, vs " + |
|
|
|
$"tensor dimensions {a.ndim}."); |
|
|
|
return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(), |
|
|
|
Binding.range(0, axe).ToArray()); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
(int a_axe, int b_axe) = (axes[0], axes[1]); |
|
|
|
return (new[] { a_axe }, new[] { b_axe }); |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException("tensordot"); |
|
|
|
} |
|
|
|
|
|
|
|
public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null) |
|
|
|
static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) |
|
|
|
{ |
|
|
|
Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) |
|
|
|
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple)))) |
|
|
|
{ |
|
|
|
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple)))) |
|
|
|
{ |
|
|
|
var shape_a = a.shape.dims; |
|
|
|
var shape_a = a.shape.as_int_list(); |
|
|
|
|
|
|
|
// axes |
|
|
|
int iter = 0; |
|
|
|
foreach (int i in axes) |
|
|
|
{ |
|
|
|
if (i >= 0) |
|
|
|
axes[0 + iter] = i; |
|
|
|
else |
|
|
|
axes[0 + iter] = i + len(shape_a); |
|
|
|
iter++; |
|
|
|
} |
|
|
|
// axes |
|
|
|
axes = axes.Select(i => i >= 0 ? i : i + len(shape_a)).ToArray(); |
|
|
|
|
|
|
|
// free |
|
|
|
int[] free = Binding.range(a.shape.ndim).Where(i => !axes.Contains(i)).ToArray(); |
|
|
|
|
|
|
|
// free_dims |
|
|
|
int[] free_dims = free.Select(i => shape_a[i]).ToArray(); |
|
|
|
|
|
|
|
// free |
|
|
|
int[] free = { }; |
|
|
|
iter = 0; |
|
|
|
foreach (int i in Enumerable.Range(0, len(axes))) |
|
|
|
if (!Array.Exists(axes, i => i == i)) |
|
|
|
free[free.Length] = i; |
|
|
|
|
|
|
|
// free_dims |
|
|
|
int[] free_dims = { }; |
|
|
|
foreach (int i in free) |
|
|
|
free_dims[free_dims.Length] = (int)shape_a[i]; |
|
|
|
|
|
|
|
int prod_free = (int)np.prod(free_dims); |
|
|
|
|
|
|
|
// prod_axes |
|
|
|
int[] prod_axes_pre = { }; |
|
|
|
foreach (int i in axes) |
|
|
|
prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i]; |
|
|
|
int prod_axes = (int)np.prod(prod_axes_pre); |
|
|
|
|
|
|
|
// perm |
|
|
|
Tensor perm; |
|
|
|
if (flipped) |
|
|
|
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free); |
|
|
|
else |
|
|
|
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free) |
|
|
|
+ ops.convert_to_tensor(list(axes)); |
|
|
|
|
|
|
|
// new_shape |
|
|
|
Shape new_shape; |
|
|
|
if (flipped) |
|
|
|
new_shape = new Shape(new int[] { prod_axes, prod_free }); |
|
|
|
else |
|
|
|
new_shape = new Shape(new int[] { prod_free, prod_axes }); |
|
|
|
int prod_free = np.prod(free_dims); |
|
|
|
|
|
|
|
// prod_axes |
|
|
|
int prod_axes = np.prod(axes.Select(i => shape_a[i]).ToArray()); |
|
|
|
|
|
|
|
// perm |
|
|
|
List<int> perm = new List<int>(); |
|
|
|
if (flipped) |
|
|
|
{ |
|
|
|
perm.AddRange(axes); |
|
|
|
perm.AddRange(free); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
perm.AddRange(free); |
|
|
|
perm.AddRange(axes); |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException("_tensordot_reshape"); |
|
|
|
// new_shape |
|
|
|
Shape new_shape; |
|
|
|
if (flipped) |
|
|
|
new_shape = new Shape(new int[] { prod_axes, prod_free }); |
|
|
|
else |
|
|
|
new_shape = new Shape(new int[] { prod_free, prod_axes }); |
|
|
|
var a_trans = a; |
|
|
|
var reshaped_a = array_ops.reshape(a_trans, new_shape); |
|
|
|
return (reshaped_a, free_dims, free_dims); |
|
|
|
} |
|
|
|
|
|
|
|
return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope => |
|
|
|
{ |
|
|
|
name = scope; |
|
|
|
var (a_axes, b_axes) = (axes[0], axes[1]); |
|
|
|
return x; |
|
|
|
}); |
|
|
|
throw new NotImplementedException("_tensordot_reshape"); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static Tensor truediv(Tensor x, Tensor y, string name = null) |
|
|
|
=> _truediv_python3(x, y, name); |
|
|
|
|
|
|
|