diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 160e1d6e..0b751c39 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -175,11 +175,25 @@ namespace Tensorflow.NumPy void SetData(IEnumerable slices, NDArray array, int currentNDim, int[] indices) { if (dtype != array.dtype) - throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned."); + array = array.astype(dtype); + // throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned."); if (!slices.Any()) return; + var newshape = ShapeHelper.GetShape(shape, slices.ToArray()); + if(newshape.Equals(array.shape)) + { + var offset = ShapeHelper.GetOffset(shape, slices.First().Start ?? 0); + unsafe + { + var dst = (byte*)data + (ulong)offset * dtypesize; + System.Buffer.MemoryCopy(array.data.ToPointer(), dst, array.bytesize, array.bytesize); + } + return; + } + + var slice = slices.First(); if (slices.Count() == 1) @@ -204,6 +218,9 @@ namespace Tensorflow.NumPy } currentNDim++; + if (slice.Stop == null) + slice.Stop = (int)dims[currentNDim]; + for (var i = slice.Start ?? 0; i < slice.Stop; i++) { indices[currentNDim] = i; diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs index 65417731..b4add508 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs @@ -2,7 +2,7 @@ using System.Collections; using System.Collections.Generic; using System.Numerics; -using System.Text; +using System.Linq; using static Tensorflow.Binding; namespace Tensorflow.NumPy @@ -10,7 +10,7 @@ namespace Tensorflow.NumPy public partial class np { [AutoNumPy] - public static NDArray any(NDArray a, Axis axis = null) => throw new NotImplementedException(""); + public static NDArray any(NDArray a, Axis axis = null) => new NDArray(a.ToArray().Any(x => x)); [AutoNumPy] public static NDArray logical_or(NDArray x1, NDArray x2) => new NDArray(tf.logical_or(x1, x2)); diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs index afecd7b9..685b0e38 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -8,6 +8,12 @@ namespace Tensorflow.NumPy { public partial class np { + [AutoNumPy] + public static NDArray concatenate(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.concat(arrays, axis)); + + [AutoNumPy] + public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); + [AutoNumPy] public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException(""); @@ -19,8 +25,5 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); - - [AutoNumPy] - public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); } } diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs index 61141cd0..7e6a2b65 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -11,7 +11,11 @@ namespace Tensorflow.NumPy public partial class np { [AutoNumPy] - public static NDArray array(Array data) => new NDArray(data); + public static NDArray array(Array data, TF_DataType? dtype = null) + { + var nd = new NDArray(data); + return dtype == null ? nd : nd.astype(dtype.Value); + } [AutoNumPy] public static NDArray array(params T[] data) diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index 89077796..cd9373d4 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -69,9 +69,6 @@ namespace Tensorflow.NumPy public static bool array_equal(NDArray a, NDArray b) => a.Equals(b); - public static NDArray concatenate(NDArray[] arrays, int axis = 0) - => throw new NotImplementedException(""); - public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8, bool equal_nan = false) => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index fe45d259..ef71be2c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -278,53 +278,37 @@ namespace Tensorflow protected static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { - TF_DataType dtype = TF_DataType.DtInvalid; - - if (x is Tensor tl) - { - dtype = tl.dtype.as_base_dtype(); - } - - if (y is Tensor tr) - { - dtype = tr.dtype.as_base_dtype(); - } - return tf_with(ops.name_scope(null, name, new { x, y }), scope => { - Tensor result; - var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); - var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); + var dtype = GetBestDType(x, y); + var x1 = ops.convert_to_tensor(x, name: "x", dtype: dtype); + var y1 = ops.convert_to_tensor(y, name: "y", dtype: dtype); + string newname = scope; - switch (name.ToLowerInvariant()) + return name.ToLowerInvariant() switch { - case "add": - result = math_ops.add_v2(x1, y1, name: scope); - break; - case "div": - result = math_ops.div(x1, y1, name: scope); - break; - case "floordiv": - result = gen_math_ops.floor_div(x1, y1, name: scope); - break; - case "truediv": - result = math_ops.truediv(x1, y1, name: scope); - break; - case "mul": - result = math_ops.multiply(x1, y1, name: scope); - break; - case "sub": - result = gen_math_ops.sub(x1, y1, name: scope); - break; - case "mod": - result = gen_math_ops.floor_mod(x1, y1, name: scope); - break; - default: - throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); - } - - return result; + "add" => math_ops.add_v2(x1, y1, name: newname), + "div" => math_ops.div(x1, y1, name: newname), + "floordiv" => gen_math_ops.floor_div(x1, y1, name: newname), + "truediv" => math_ops.truediv(x1, y1, name: newname), + "mul" => math_ops.multiply(x1, y1, name: newname), + "sub" => gen_math_ops.sub(x1, y1, name: newname), + "mod" => gen_math_ops.floor_mod(x1, y1, name: newname), + _ => throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}") + }; }); } + + static TF_DataType GetBestDType(Tx x, Ty y) + { + var dtype1 = x.GetDataType(); + var dtype2 = y.GetDataType(); + if (dtype1.is_integer() && dtype2.is_floating()) + return dtype2; + else if (dtype1.is_floating() && dtype2.is_integer()) + return dtype1; + else + return dtype1; + } } } diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs index a3306f88..072f0079 100644 --- a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs +++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs @@ -35,5 +35,13 @@ namespace TensorFlowNET.UnitTest.NumPy var x1 = x.astype(np.float32); Assert.AreEqual(x1[2], 200f); } + + [TestMethod] + public void divide() + { + var x = np.array(new float[] { 1, 100, 200 }); + var y = x / 2; + Assert.AreEqual(y.dtype, np.float32); + } } }