From e859b20d09260896d98bd0b008c2876e87c0db86 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 9 Aug 2021 23:10:27 -0500 Subject: [PATCH] np.expand_dims --- src/TensorFlowNET.Core/APIs/tf.array.cs | 5 ++-- .../NumPy/Numpy.Manipulation.cs | 2 +- .../Operations/array_ops.cs | 5 +--- .../NumPy/Manipulation.Test.cs | 28 +++++++++++++++++++ 4 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index be614294..1d2e55a7 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -99,13 +99,12 @@ namespace Tensorflow /// /// /// - /// /// /// A `Tensor` with the same data as `input`, but its shape has an additional /// dimension of size 1 added. /// - public Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) - => array_ops.expand_dims(input, axis, name, dim); + public Tensor expand_dims(Tensor input, int axis = -1, string name = null) + => array_ops.expand_dims(input, axis, name); /// /// Creates a tensor filled with a scalar value. diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs index 685b0e38..698e6fcc 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -15,7 +15,7 @@ namespace Tensorflow.NumPy public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); [AutoNumPy] - public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException(""); + public static NDArray expand_dims(NDArray a, Axis? axis = null) => new NDArray(array_ops.expand_dims(a, axis: axis ?? -1)); [AutoNumPy] public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index e821dfb0..3dc8cf12 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -300,10 +300,7 @@ namespace Tensorflow return result; } - public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) - => expand_dims_v2(input, axis, name); - - private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) + public static Tensor expand_dims(Tensor input, int axis = -1, string name = null) => gen_array_ops.expand_dims(input, axis, name); /// diff --git a/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs new file mode 100644 index 00000000..a7437f66 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs @@ -0,0 +1,28 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/routines.array-manipulation.html + /// + [TestClass] + public class ManipulationTest : EagerModeTestBase + { + [TestMethod] + public void expand_dims() + { + var x = np.array(new[] { 1, 2 }); + var y = np.expand_dims(x, axis: 0); + Assert.AreEqual(y.shape, (1, 2)); + + y = np.expand_dims(x, axis: 1); + Assert.AreEqual(y.shape, (2, 1)); + } + } +}