diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 0073b4f5..9797e861 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -57,7 +57,7 @@ namespace Tensorflow.NumPy } } - + [AutoNumPy] unsafe NDArray GetData(Slice[] slices) { if (shape.IsScalar) @@ -170,9 +170,9 @@ namespace Tensorflow.NumPy } void SetData(IEnumerable slices, NDArray array) - => SetData(array, data, slices.ToArray(), new int[shape.ndim].ToArray(), -1); + => SetData(array, slices.ToArray(), new int[shape.ndim].ToArray(), -1); - unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int currentNDim) + unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim) { if (dtype != src.dtype) src = src.astype(dtype); @@ -181,20 +181,23 @@ namespace Tensorflow.NumPy if (!slices.Any()) return; + if (shape.Equals(src.shape)) + { + System.Buffer.MemoryCopy(src.data.ToPointer(), data.ToPointer(), src.bytesize, src.bytesize); + return; + } + // first iteration if(currentNDim == -1) { slices = SliceHelper.AlignWithShape(shape, slices); - if (!shape.Equals(src.shape)) - { - var newShape = ShapeHelper.AlignWithShape(shape, src.shape); - src = src.reshape(newShape); - } } // last dimension if (currentNDim == ndim - 1) { + var offset = (int)ShapeHelper.GetOffset(shape, indices); + var dst = data + offset * (int)dtypesize; System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize); return; } @@ -206,13 +209,56 @@ namespace Tensorflow.NumPy var stop = slice.Stop ?? (int)dims[currentNDim]; var step = slice.Step; - for (var i = start; i < stop; i += step) + if(step != 1) { - indices[currentNDim] = i; - var offset = (int)ShapeHelper.GetOffset(shape, indices); - dst = data + offset * (int)dtypesize; - var srcIndex = (i - start) / step; - SetData(src[srcIndex], dst, slices, indices, currentNDim); + for (var i = start; i < stop; i += step) + { + if (i >= dims[currentNDim]) + throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}"); + + indices[currentNDim] = i; + if (currentNDim < ndim - src.ndim) + { + SetData(src, slices, indices, currentNDim); + } + else + { + var srcIndex = (i - start) / step; + SetData(src[srcIndex], slices, indices, currentNDim); + } + } + } + else + { + for (var i = start; i < stop; i++) + { + if (i >= dims[currentNDim]) + throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}"); + + indices[currentNDim] = i; + if (currentNDim < ndim - src.ndim) + { + SetData(src, slices, indices, currentNDim); + } + // last dimension + else if(currentNDim == ndim - 1) + { + SetData(src, slices, indices, currentNDim); + break; + } + else if(SliceHelper.IsContinuousBlock(slices, currentNDim)) + { + var offset = (int)ShapeHelper.GetOffset(shape, indices); + var dst = data + offset * (int)dtypesize; + System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize); + return; + } + else + { + var srcIndex = i - start; + SetData(src[srcIndex], slices, indices, currentNDim); + } + } } // reset indices diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs index 376183f3..e5bcf749 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs @@ -16,6 +16,8 @@ namespace Tensorflow.NumPy public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs)); [AutoNumPy] public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs)); + [AutoNumPy] + public static NDArray operator %(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mod", lhs, rhs)); [AutoNumPy] public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs)); [AutoNumPy] diff --git a/src/TensorFlowNET.Core/NumPy/SliceHelper.cs b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs index d0739eff..30a14c9e 100644 --- a/src/TensorFlowNET.Core/NumPy/SliceHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs @@ -55,5 +55,16 @@ namespace Tensorflow.NumPy } return true; } + + public static bool IsContinuousBlock(Slice[] slices, int ndim) + { + for (int i = ndim + 1; i < slices.Length; i++) + { + if (slices[i].Equals(Slice.All)) + continue; + return false; + } + return true; + } } } diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs index 573c2fd2..41bf1264 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.UnitTest.NumPy } [TestMethod] - public void slice_step() + public void slice_step_setter() { var array = np.arange(32).reshape((4, 8)); var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1; @@ -131,5 +131,17 @@ namespace TensorFlowNET.UnitTest.NumPy Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 }); Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 }); } + + [TestMethod] + public void slice_step_setter_diff_shape() + { + var array = np.arange(32).reshape((4, 8)); + var s1 = np.array(new[] { 100, 200 }); + array[Slice.All, new Slice(2, 5, 2)] = s1; + Assert.AreEqual(array[0], new[] { 0, 1, 100, 3, 200, 5, 6, 7 }); + Assert.AreEqual(array[1], new[] { 8, 9, 100, 11, 200, 13, 14, 15 }); + Assert.AreEqual(array[2], new[] { 16, 17, 100, 19, 200, 21, 22, 23 }); + Assert.AreEqual(array[3], new[] { 24, 25, 100, 27, 200, 29, 30, 31 }); + } } }