diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index be340789..a6022917 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -10,24 +10,32 @@ namespace Tensorflow.NumPy { public NDArray this[params int[] index] { - get => _tensor[index.Select(x => new Slice + get => GetData(index.Select(x => new Slice { Start = x, Stop = x + 1, IsIndex = true - }).ToArray()]; + })); - set => SetData(index.Select(x => new Slice + set => SetData(index.Select(x => { - Start = x, - Stop = x + 1, - IsIndex = true + if(x < 0) + x = (int)dims[0] + x; + + var slice = new Slice + { + Start = x, + Stop = x + 1, + IsIndex = true + }; + + return slice; }), value); } public NDArray this[params Slice[] slices] { - get => _tensor[slices]; + get => GetData(slices); set => SetData(slices, value); } @@ -44,6 +52,11 @@ namespace Tensorflow.NumPy } } + NDArray GetData(IEnumerable slices) + { + return _tensor[slices.ToArray()]; + } + void SetData(IEnumerable slices, NDArray array) => SetData(slices, array, -1, slices.Select(x => 0).ToArray()); @@ -61,7 +74,10 @@ namespace Tensorflow.NumPy { if (slice.Step != 1) - throw new NotImplementedException(""); + throw new NotImplementedException("slice.step should == 1"); + + if (slice.Start < 0) + throw new NotImplementedException("slice.start should > -1"); indices[indices.Length - 1] = slice.Start ?? 0; var offset = (ulong)ShapeHelper.GetOffset(shape, indices); diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs index 538d5867..1493b05d 100644 --- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -81,6 +81,9 @@ namespace Tensorflow.NumPy for (int i = 0; i < indices.Length; i++) offset += strides[i] * indices[i]; + if (offset < 0) + throw new NotImplementedException(""); + return offset; } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index 642e3571..11a53279 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -29,7 +29,7 @@ namespace Tensorflow var tstr = c_api.TF_TensorData(handle); #if TRACK_TENSOR_LIFE - print($"New TString 0x{handle.ToString("x16")} {AllocationType} Data: 0x{tstr.ToString("x16")}"); + print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}"); #endif for (int i = 0; i < buffer.Length; i++) { diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs index 8100be3f..d1408445 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Text; using Tensorflow; using Tensorflow.NumPy; +using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.NumPy { @@ -53,5 +54,14 @@ namespace TensorFlowNET.UnitTest.NumPy Assert.AreEqual(y.shape, (1, 2)); Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2))); } + + [TestMethod] + public void slice_out_bound() + { + var input_shape = tf.constant(new int[] { 1, 1 }); + var input_shape_val = input_shape.numpy(); + input_shape_val[(int)input_shape.size - 1] = 1; + input_shape.Dispose(); + } } }