diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs index 0e50cd56..ea85048f 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs @@ -28,7 +28,16 @@ namespace Tensorflow.NumPy public static NDArray multiply(NDArray x1, NDArray x2) => new NDArray(tf.multiply(x1, x2)); [AutoNumPy] - public static NDArray maximum(NDArray x1, NDArray x2) => new NDArray(tf.maximum(x1, x2)); + //public static NDArray maximum(NDArray x1, NDArray x2) => new NDArray(tf.maximum(x1, x2)); + public static NDArray maximum(NDArray x1, NDArray x2, int? axis = null) + { + var maxValues = tf.maximum(x1, x2); + if (axis.HasValue) + { + maxValues = tf.reduce_max(maxValues, axis: axis.Value); + } + return new NDArray(maxValues); + } [AutoNumPy] public static NDArray minimum(NDArray x1, NDArray x2) => new NDArray(tf.minimum(x1, x2));