
tf.tensordot #898

tags/TimeSeries
Oceania2018 3 years ago
parent commit 824308a15c
5 changed files with 96 additions and 112 deletions

  1. +3  -0    src/TensorFlowNET.Core/APIs/tf.linalg.cs
  2. +1  -1    src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs
  3. +8  -0    src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  4. +70 -111  src/TensorFlowNET.Core/Operations/math_ops.cs
  5. +14 -0    test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

+3 -0   src/TensorFlowNET.Core/APIs/tf.linalg.cs

@@ -57,6 +57,9 @@ namespace Tensorflow
 public Tensor lstsq(Tensor matrix, Tensor rhs,
 NDArray l2_regularizer = null, bool fast = true, string name = null)
 => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
+
+public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null)
+=> math_ops.tensordot(x, y, axes, name: name);
 }

 public Tensor diag(Tensor diagonal, string name = null)
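
The new overload exposes math_ops.tensordot through the public tf.linalg API. A minimal usage sketch, mirroring the unit test added at the bottom of this commit (the scalar and int-pair forms of the NDArray axes argument are the ones exercised there):

using static Tensorflow.Binding;   // assumed: the usual TF.NET entry point that exposes `tf`

var a = tf.constant(new[] { 1, 2 });
var b = tf.constant(new[] { 2, 3 });

// axes = 0: contract nothing, i.e. the outer product, shape (2, 2)
var outer = tf.linalg.tensordot(a, b, 0);

// axes = { 0, 0 }: contract axis 0 of a with axis 0 of b, i.e. the dot product, a scalar
var dot = tf.linalg.tensordot(a, b, new[] { 0, 0 });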


+1 -1   src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs

@@ -20,7 +20,7 @@ namespace Tensorflow.NumPy
 if(a.rank != weights.rank)
 {
 var weights_sum = math_ops.reduce_sum(tensorW);
-var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } });
+var axes = np.array(new[,] { { axis }, { 0 } });
 var avg = math_ops.tensordot(a, weights, axes) / weights_sum;
 }
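
Here the two-row axes array pairs one contraction axis per operand: the first row lists axes of a, the second row axes of weights, so { { axis }, { 0 } } contracts dimension `axis` of a against the single dimension of weights, and dividing by weights_sum gives the weighted average. A hedged sketch of the same pattern through the public API (tensor names and values are illustrative only, and tf.reduce_sum is assumed to be available):

using Tensorflow.NumPy;
using static Tensorflow.Binding;

// Weighted average along axis 1 of `data`: contract axis 1 of `data` with axis 0 of `w`.
var data = tf.constant(new[,] { { 1f, 2f }, { 3f, 4f } });   // shape (2, 2)
var w = tf.constant(new[] { 0.25f, 0.75f });                 // shape (2)
var axes = np.array(new[,] { { 1 }, { 0 } });
var weighted = tf.linalg.tensordot(data, w, axes);           // expected shape (2): { 1.75f, 3.75f }
var avg = weighted / tf.reduce_sum(w);                       // sum of w is 1, so avg equals weighted here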


+8 -0   src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

@@ -104,6 +104,14 @@ namespace Tensorflow.NumPy
 if (shape.ndim != shape3.Length)
 return false;
 return Enumerable.SequenceEqual(shape.as_int_list(), shape3);
+case List<long> shape4:
+if (shape.ndim != shape4.Count)
+return false;
+return Enumerable.SequenceEqual(shape.dims, shape4);
+case List<int> shape5:
+if (shape.ndim != shape5.Count)
+return false;
+return Enumerable.SequenceEqual(shape.as_int_list(), shape5);
 default:
 return false;
 }
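
These two cases let a Shape compare element-wise against a plain list of dimensions; the rewritten tensordot below relies on this when it checks ab_matmul.shape.Equals(dims) with dims being a List<int>. A small hedged sketch, assuming Shape.Equals(object) dispatches through this helper:

using System.Collections.Generic;
using Tensorflow;

var shape = new Shape(2, 2);
var sameInt  = shape.Equals(new List<int>  { 2, 2 });   // expected: true
var sameLong = shape.Equals(new List<long> { 2, 2 });   // expected: true
var other    = shape.Equals(new List<int>  { 2, 3 });   // expected: false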


+70 -111   src/TensorFlowNET.Core/Operations/math_ops.cs

@@ -868,133 +868,92 @@ namespace Tensorflow
 public static Tensor tanh(Tensor x, string name = null)
 => gen_math_ops.tanh(x, name);

-public static Tensor tensordot(Tensor x, Tensor y, int[] axes, string name = null)
+public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null)
 {
-Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
+return tf_with(ops.name_scope(name, "Tensordot", new { a, b, axes }), scope =>
 {
-if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple))))
-{
-var shape_a = a.shape.dims;
-
-// axes
-int iter = 0;
-foreach (int i in axes)
-{
-if (i >= 0)
-axes[0 + iter] = i;
-else
-axes[0 + iter] = i + len(shape_a);
-iter++;
-}
-
-// free
-int[] free = { };
-iter = 0;
-foreach (int i in Enumerable.Range(0, len(axes)))
-if (!Array.Exists(axes, i => i == i))
-free[free.Length] = i;
-
-// free_dims
-int[] free_dims = { };
-foreach (int i in free)
-free_dims[free_dims.Length] = (int)shape_a[i];
-
-int prod_free = (int)np.prod(free_dims);
-
-// prod_axes
-int[] prod_axes_pre = { };
-foreach (int i in axes)
-prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i];
-int prod_axes = (int)np.prod(prod_axes_pre);
-
-// perm
-Tensor perm;
-if (flipped)
-perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free);
-else
-perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free)
-+ ops.convert_to_tensor(list(axes));
-
-// new_shape
-Shape new_shape;
-if (flipped)
-new_shape = new Shape(new int[] { prod_axes, prod_free });
-else
-new_shape = new Shape(new int[] { prod_free, prod_axes });
-}
+name = scope;
+var (a_axes, b_axes) = _tensordot_axes(a, axes);
+var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
+var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
+var ab_matmul = matmul(a_reshape, b_reshape);
+var dims = new List<int>();
+dims.AddRange(a_free_dims);
+dims.AddRange(b_free_dims);
+if (ab_matmul.shape.Equals(dims))
+return ab_matmul;
+else
+return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
+});
 }

-throw new NotImplementedException("_tensordot_reshape");
+static (int[], int[]) _tensordot_axes(Tensor a, NDArray axes)
 {
+if (axes.rank == 0)
+{
+int axe = axes;
+if (axe > a.shape.ndim)
+throw new ValueError("`axes` must not be larger than the number of " +
+$"dimensions of tensor {a}. Received {axes}, vs " +
+$"tensor dimensions {a.ndim}.");
+return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
+Binding.range(0, axe).ToArray());
+}
+else
+{
+(int a_axe, int b_axe) = (axes[0], axes[1]);
+return (new[] { a_axe }, new[] { b_axe });
+}

-throw new NotImplementedException("tensordot");
 }

-public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null)
+static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
 {
-Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
+if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
 {
-if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple))))
-{
-var shape_a = a.shape.dims;
+var shape_a = a.shape.as_int_list();

-// axes
-int iter = 0;
-foreach (int i in axes)
-{
-if (i >= 0)
-axes[0 + iter] = i;
-else
-axes[0 + iter] = i + len(shape_a);
-iter++;
-}
+// axes
+axes = axes.Select(i => i >= 0 ? i : i + len(shape_a)).ToArray();

+// free
+int[] free = Binding.range(a.shape.ndim).Where(i => !axes.Contains(i)).ToArray();
+
+// free_dims
+int[] free_dims = free.Select(i => shape_a[i]).ToArray();
+
-// free
-int[] free = { };
-iter = 0;
-foreach (int i in Enumerable.Range(0, len(axes)))
-if (!Array.Exists(axes, i => i == i))
-free[free.Length] = i;
-
-// free_dims
-int[] free_dims = { };
-foreach (int i in free)
-free_dims[free_dims.Length] = (int)shape_a[i];
-
-int prod_free = (int)np.prod(free_dims);

-// prod_axes
-int[] prod_axes_pre = { };
-foreach (int i in axes)
-prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i];
-int prod_axes = (int)np.prod(prod_axes_pre);
-
-// perm
-Tensor perm;
-if (flipped)
-perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free);
-else
-perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free)
-+ ops.convert_to_tensor(list(axes));
-
-// new_shape
-Shape new_shape;
-if (flipped)
-new_shape = new Shape(new int[] { prod_axes, prod_free });
-else
-new_shape = new Shape(new int[] { prod_free, prod_axes });
+int prod_free = np.prod(free_dims);

+// prod_axes
+int prod_axes = np.prod(axes.Select(i => shape_a[i]).ToArray());
+
+// perm
+List<int> perm = new List<int>();
+if (flipped)
+{
+perm.AddRange(axes);
+perm.AddRange(free);
+}
+else
+{
+perm.AddRange(free);
+perm.AddRange(axes);
+}

-throw new NotImplementedException("_tensordot_reshape");
+// new_shape
+Shape new_shape;
+if (flipped)
+new_shape = new Shape(new int[] { prod_axes, prod_free });
+else
+new_shape = new Shape(new int[] { prod_free, prod_axes });
+var a_trans = a;
+var reshaped_a = array_ops.reshape(a_trans, new_shape);
+return (reshaped_a, free_dims, free_dims);
 }

-return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope =>
-{
-name = scope;
-var (a_axes, b_axes) = (axes[0], axes[1]);
-return x;
-});
+throw new NotImplementedException("_tensordot_reshape");
 }


 public static Tensor truediv(Tensor x, Tensor y, string name = null)
 => _truediv_python3(x, y, name);
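
The rewritten tensordot follows the standard reduction of a tensor contraction to a single matmul: _tensordot_axes normalizes the axes argument into one list of contracted axes per operand, _tensordot_reshape moves the contracted axes to the back of a (and, in the flipped case, to the front of b) and flattens each operand to 2-D, and the final reshape restores the concatenated free dimensions. A hedged standalone sketch of that idea using public ops (tf.ones, tf.transpose, tf.reshape and tf.matmul are assumed to accept these argument forms; the code above additionally handles scalar axes and shapes that are not fully defined):

using Tensorflow;
using static Tensorflow.Binding;

// tensordot contracting axis 1 of a with axis 0 of b; a: (2, 3, 4), b: (3, 5) -> result (2, 4, 5)
var a = tf.ones(new Shape(2, 3, 4));
var b = tf.ones(new Shape(3, 5));

// 1. perm = free axes of a followed by its contracted axes: [0, 2, 1]
var aT = tf.transpose(a, new[] { 0, 2, 1 });     // (2, 4, 3)
// 2. flatten to (prod(free), prod(contracted)) and (prod(contracted), prod(free))
var a2d = tf.reshape(aT, new[] { 8, 3 });
var b2d = tf.reshape(b, new[] { 3, 5 });         // contracted axis of b already leads, so no transpose
// 3. a single matmul performs the contraction
var ab = tf.matmul(a2d, b2d);                    // (8, 5)
// 4. restore the free dims of a followed by the free dims of b
var result = tf.reshape(ab, new[] { 2, 4, 5 });  // (2, 4, 5)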



+14 -0   test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

@@ -63,5 +63,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
 var norm = tf.linalg.global_norm(t_list);
 Assert.AreEqual(norm.numpy(), 14.282857f);
 }
+
+[TestMethod]
+public void Tensordot()
+{
+var a = tf.constant(new[] { 1, 2 });
+var b = tf.constant(new[] { 2, 3 });
+var c = tf.linalg.tensordot(a, b, 0);
+Assert.AreEqual(c.shape, (2, 2));
+AssetSequenceEqual(c.ToArray<int>(), new[] { 2, 3, 4, 6 });
+
+c = tf.linalg.tensordot(a, b, new[] { 0, 0 });
+Assert.AreEqual(c.shape.ndim, 0);
+Assert.AreEqual(c.numpy(), 8);
+}
 }
 }
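
For reference, the expected values follow directly from the two axes forms used above:

// axes = 0        : nothing is contracted, so c is the outer product
//                   [[1*2, 1*3], [2*2, 2*3]] = [[2, 3], [4, 6]], shape (2, 2)
// axes = { 0, 0 } : axis 0 of a is contracted with axis 0 of b, the ordinary dot product
//                   1*2 + 2*3 = 8, a rank-0 (scalar) tensor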
