diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs index 5aecfd6e..61feb5e7 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs @@ -14,7 +14,7 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray argsort(NDArray a, Axis axis = null) - => new NDArray(math_ops.argmax(a, axis ?? -1)); + => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); [AutoNumPy] public static (NDArray, NDArray) unique(NDArray a) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 346ba2dd..5b09810e 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -281,7 +281,7 @@ namespace Tensorflow.Operations data_format })); - public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) + public static Tensor[] top_kv2(Tensor input, T k, bool sorted = true, string name = null) { var _op = tf.OpDefLib._apply_op_helper("TopKV2", name: name, args: new { diff --git a/src/TensorFlowNET.Core/Operations/sort_ops.cs b/src/TensorFlowNET.Core/Operations/sort_ops.cs new file mode 100644 index 00000000..314daefd --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/sort_ops.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class sort_ops + { + public static Tensor argsort(Tensor values, Axis axis = null, string direction = "ASCENDING", bool stable = false, string name = null) + { + axis = axis ?? new Axis(-1); + var k = array_ops.shape(values)[axis]; + values = -values; + var (_, indices) = tf.Context.ExecuteOp("TopKV2", name, + new ExecuteOpArgs(values, k).SetAttributes(new + { + sorted = true + })); + return indices; + } + + public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) + => tf.Context.ExecuteOp("MatrixInverse", name, + new ExecuteOpArgs(input).SetAttributes(new + { + adjoint + })); + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs new file mode 100644 index 00000000..2a617d40 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs @@ -0,0 +1,34 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/user/basics.indexing.html + /// + [TestClass] + public class ArraySortingTest : EagerModeTestBase + { + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + /// + [TestMethod] + public void argsort() + { + var x = np.array(new[] { 3, 1, 2 }); + var ind = np.argsort(x); + Assert.AreEqual(ind, new[] { 1, 2, 0 }); + + var y = np.array(new[,] { { 0, 3 }, { 2, 2 } }); + ind = np.argsort(y, axis: 0); + Assert.AreEqual(ind[0], new[] { 0, 1 }); + Assert.AreEqual(ind[1], new[] { 1, 0 }); + } + } +}