Browse Source

add numpy api of np.moveaxis #891

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
2815724887
3 changed files with 37 additions and 0 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
  2. +20
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +14
    -0
      test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs

+ 3
- 0
src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs View File

@@ -25,5 +25,8 @@ namespace Tensorflow.NumPy

[AutoNumPy]
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));

[AutoNumPy]
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
}
}

+ 20
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -792,6 +792,26 @@ namespace Tensorflow
});
}

public static Tensor moveaxis(NDArray array, Axis source, Axis destination)
{
List<int> perm = null;
source = source.axis.Select(x => x < 0 ? array.rank + x : x).ToArray();
destination = destination.axis.Select(x => x < 0 ? array.rank + x : x).ToArray();

if (array.shape.rank > -1)
{
perm = range(0, array.rank).Where(i => !source.axis.Contains(i)).ToList();
foreach (var (dest, src) in zip(destination.axis, source.axis).OrderBy(x => x.Item1))
{
perm.Insert(dest, src);
}
}
else
throw new NotImplementedException("");

return array_ops.transpose(array, perm.ToArray());
}

/// <summary>
/// Computes the shape of a broadcast given symbolic shapes.
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of


+ 14
- 0
test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs View File

@@ -24,5 +24,19 @@ namespace TensorFlowNET.UnitTest.NumPy
y = np.expand_dims(x, axis: 1);
Assert.AreEqual(y.shape, (2, 1));
}

[TestMethod]
public void moveaxis()
{
var x = np.zeros((3, 4, 5));
var y = np.moveaxis(x, 0, -1);
Assert.AreEqual(y.shape, (4, 5, 3));

y = np.moveaxis(x, (0, 1), (-1, -2));
Assert.AreEqual(y.shape, (5, 4, 3));

y = np.moveaxis(x, (0, 1, 2), (-1, -2, -3));
Assert.AreEqual(y.shape, (5, 4, 3));
}
}
}

Loading…
Cancel
Save