@@ -99,13 +99,12 @@ namespace Tensorflow | |||||
/// <param name="input"></param> | /// <param name="input"></param> | ||||
/// <param name="axis"></param> | /// <param name="axis"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <param name="dim"></param> | |||||
/// <returns> | /// <returns> | ||||
/// A `Tensor` with the same data as `input`, but its shape has an additional | /// A `Tensor` with the same data as `input`, but its shape has an additional | ||||
/// dimension of size 1 added. | /// dimension of size 1 added. | ||||
/// </returns> | /// </returns> | ||||
public Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) | |||||
=> array_ops.expand_dims(input, axis, name, dim); | |||||
public Tensor expand_dims(Tensor input, int axis = -1, string name = null) | |||||
=> array_ops.expand_dims(input, axis, name); | |||||
/// <summary> | /// <summary> | ||||
/// Creates a tensor filled with a scalar value. | /// Creates a tensor filled with a scalar value. | ||||
@@ -15,7 +15,7 @@ namespace Tensorflow.NumPy | |||||
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); | public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); | ||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException(""); | |||||
public static NDArray expand_dims(NDArray a, Axis? axis = null) => new NDArray(array_ops.expand_dims(a, axis: axis ?? -1)); | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape); | public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape); | ||||
@@ -300,10 +300,7 @@ namespace Tensorflow | |||||
return result; | return result; | ||||
} | } | ||||
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) | |||||
=> expand_dims_v2(input, axis, name); | |||||
private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) | |||||
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null) | |||||
=> gen_array_ops.expand_dims(input, axis, name); | => gen_array_ops.expand_dims(input, axis, name); | ||||
/// <summary> | /// <summary> | ||||
@@ -0,0 +1,28 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | |||||
namespace TensorFlowNET.UnitTest.NumPy | |||||
{ | |||||
/// <summary> | |||||
/// https://numpy.org/doc/stable/reference/routines.array-manipulation.html | |||||
/// </summary> | |||||
[TestClass] | |||||
public class ManipulationTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void expand_dims() | |||||
{ | |||||
var x = np.array(new[] { 1, 2 }); | |||||
var y = np.expand_dims(x, axis: 0); | |||||
Assert.AreEqual(y.shape, (1, 2)); | |||||
y = np.expand_dims(x, axis: 1); | |||||
Assert.AreEqual(y.shape, (2, 1)); | |||||
} | |||||
} | |||||
} |