From 1b1a50371b0829363d1f9c469aedbe727a6ec41f Mon Sep 17 00:00:00 2001
From: Visagan Guruparan <103048@smsassist.com>
Date: Sun, 18 Jun 2023 22:46:36 -0500
Subject: [PATCH] np update square and dot product
---
src/TensorFlowNET.Core/APIs/tf.math.cs | 15 ++++++++--
src/TensorFlowNET.Core/Binding.Util.cs | 23 ++++++++++++++-
src/TensorFlowNET.Core/NumPy/Numpy.Math.cs | 21 ++++++++++++++
.../TensorFlowNET.UnitTest/Numpy/Math.Test.cs | 29 ++++++++++++++++++-
4 files changed, 84 insertions(+), 4 deletions(-)
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 75253700..0e53d938 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.NumPy;
using Tensorflow.Operations;
namespace Tensorflow
@@ -42,7 +43,6 @@ namespace Tensorflow
public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);
-
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);
@@ -452,7 +452,18 @@ namespace Tensorflow
///
public Tensor multiply(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
-
+ ///
+ /// return scalar product
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor dot_prod(Tx x, Ty y, NDArray axes, string name = null)
+ => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
public Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 8df39334..e414ef6e 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -486,7 +486,28 @@ namespace Tensorflow
throw new NotImplementedException("");
}
}
-
+ public static NDArray GetFlattenArray(NDArray x)
+ {
+ switch (x.GetDataType())
+ {
+ case TF_DataType.TF_FLOAT:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_DOUBLE:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT16:
+ case TF_DataType.TF_INT32:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT64:
+ x = x.ToArray();
+ break;
+ default:
+ break;
+ }
+ return x;
+ }
public static TF_DataType GetDataType(this object data)
{
var type = data.GetType();
diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
index ea85048f..5bc97952 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
@@ -49,9 +49,30 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static NDArray prod(params T[] array) where T : unmanaged
=> new NDArray(tf.reduce_prod(new NDArray(array)));
+ [AutoNumPy]
+ public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null)
+ {
+ //if axes mentioned
+ if (axes != null)
+ {
+ return new NDArray(tf.dot_prod(x1, x2, axes, name));
+ }
+ if (x1.shape.ndim > 1)
+ {
+ x1 = GetFlattenArray(x1);
+ }
+ if (x2.shape.ndim > 1)
+ {
+ x2 = GetFlattenArray(x2);
+ }
+ //if axes not mentioned, default 0,0
+ return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name));
+ }
[AutoNumPy]
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y));
+ [AutoNumPy]
+ public static NDArray square(NDArray x) => new NDArray(tf.square(x));
[AutoNumPy]
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x));
diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
index 32b517e4..65cdaedd 100644
--- a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
+++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
@@ -65,7 +65,34 @@ namespace TensorFlowNET.UnitTest.NumPy
var y = np.power(x, 3);
Assert.AreEqual(y, new[] { 0, 1, 8, 27, 64, 125 });
}
- [TestMethod]
+ [TestMethod]
+ public void square()
+ {
+ var x = np.arange(6);
+ var y = np.square(x);
+ Assert.AreEqual(y, new[] { 0, 1, 4, 9, 16, 25 });
+ }
+ [TestMethod]
+ public void dotproduct()
+ {
+ var x1 = new NDArray(new[] { 1, 2, 3 });
+ var x2 = new NDArray(new[] { 4, 5, 6 });
+ double result1 = np.dot(x1, x2);
+ NDArray y1 = new float[,] {
+ { 1.0f, 2.0f, 3.0f },
+ { 4.0f, 5.1f,6.0f },
+ { 4.0f, 5.1f,6.0f }
+ };
+ NDArray y2 = new float[,] {
+ { 3.0f, 2.0f, 1.0f },
+ { 6.0f, 5.1f, 4.0f },
+ { 6.0f, 5.1f, 4.0f }
+ };
+ double result2 = np.dot(y1, y2);
+ Assert.AreEqual(result1, 32);
+ Assert.AreEqual(Math.Round(result2, 2), 158.02);
+ }
+ [TestMethod]
public void maximum()
{
var x1 = new NDArray(new[,] { { 1, 2, 3 }, { 4, 5.1, 6 } });