Browse Source

slice_step_setter_diff_shape

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
30f8f67351
4 changed files with 86 additions and 15 deletions
  1. +60
    -14
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  2. +2
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
  3. +11
    -0
      src/TensorFlowNET.Core/NumPy/SliceHelper.cs
  4. +13
    -1
      test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

+ 60
- 14
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -57,7 +57,7 @@ namespace Tensorflow.NumPy
}
}

[AutoNumPy]
unsafe NDArray GetData(Slice[] slices)
{
if (shape.IsScalar)
@@ -170,9 +170,9 @@ namespace Tensorflow.NumPy
}

void SetData(IEnumerable<Slice> slices, NDArray array)
=> SetData(array, data, slices.ToArray(), new int[shape.ndim].ToArray(), -1);
=> SetData(array, slices.ToArray(), new int[shape.ndim].ToArray(), -1);

unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int currentNDim)
unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim)
{
if (dtype != src.dtype)
src = src.astype(dtype);
@@ -181,20 +181,23 @@ namespace Tensorflow.NumPy
if (!slices.Any())
return;

if (shape.Equals(src.shape))
{
System.Buffer.MemoryCopy(src.data.ToPointer(), data.ToPointer(), src.bytesize, src.bytesize);
return;
}

// first iteration
if(currentNDim == -1)
{
slices = SliceHelper.AlignWithShape(shape, slices);
if (!shape.Equals(src.shape))
{
var newShape = ShapeHelper.AlignWithShape(shape, src.shape);
src = src.reshape(newShape);
}
}

// last dimension
if (currentNDim == ndim - 1)
{
var offset = (int)ShapeHelper.GetOffset(shape, indices);
var dst = data + offset * (int)dtypesize;
System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
return;
}
@@ -206,13 +209,56 @@ namespace Tensorflow.NumPy
var stop = slice.Stop ?? (int)dims[currentNDim];
var step = slice.Step;

for (var i = start; i < stop; i += step)
if(step != 1)
{
indices[currentNDim] = i;
var offset = (int)ShapeHelper.GetOffset(shape, indices);
dst = data + offset * (int)dtypesize;
var srcIndex = (i - start) / step;
SetData(src[srcIndex], dst, slices, indices, currentNDim);
for (var i = start; i < stop; i += step)
{
if (i >= dims[currentNDim])
throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}");

indices[currentNDim] = i;
if (currentNDim < ndim - src.ndim)
{
SetData(src, slices, indices, currentNDim);
}
else
{
var srcIndex = (i - start) / step;
SetData(src[srcIndex], slices, indices, currentNDim);
}
}
}
else
{
for (var i = start; i < stop; i++)
{
if (i >= dims[currentNDim])
throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}");

indices[currentNDim] = i;
if (currentNDim < ndim - src.ndim)
{
SetData(src, slices, indices, currentNDim);
}
// last dimension
else if(currentNDim == ndim - 1)
{
SetData(src, slices, indices, currentNDim);
break;
}
else if(SliceHelper.IsContinuousBlock(slices, currentNDim))
{
var offset = (int)ShapeHelper.GetOffset(shape, indices);
var dst = data + offset * (int)dtypesize;
System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
return;
}
else
{
var srcIndex = i - start;
SetData(src[srcIndex], slices, indices, currentNDim);
}
}
}

// reset indices


+ 2
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs View File

@@ -16,6 +16,8 @@ namespace Tensorflow.NumPy
public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs));
[AutoNumPy]
public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs));
[AutoNumPy]
public static NDArray operator %(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mod", lhs, rhs));
[AutoNumPy]
public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs));
[AutoNumPy]


+ 11
- 0
src/TensorFlowNET.Core/NumPy/SliceHelper.cs View File

@@ -55,5 +55,16 @@ namespace Tensorflow.NumPy
}
return true;
}

public static bool IsContinuousBlock(Slice[] slices, int ndim)
{
for (int i = ndim + 1; i < slices.Length; i++)
{
if (slices[i].Equals(Slice.All))
continue;
return false;
}
return true;
}
}
}

+ 13
- 1
test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs View File

@@ -118,7 +118,7 @@ namespace TensorFlowNET.UnitTest.NumPy
}

[TestMethod]
public void slice_step()
public void slice_step_setter()
{
var array = np.arange(32).reshape((4, 8));
var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1;
@@ -131,5 +131,17 @@ namespace TensorFlowNET.UnitTest.NumPy
Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 });
Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 });
}

[TestMethod]
public void slice_step_setter_diff_shape()
{
var array = np.arange(32).reshape((4, 8));
var s1 = np.array(new[] { 100, 200 });
array[Slice.All, new Slice(2, 5, 2)] = s1;
Assert.AreEqual(array[0], new[] { 0, 1, 100, 3, 200, 5, 6, 7 });
Assert.AreEqual(array[1], new[] { 8, 9, 100, 11, 200, 13, 14, 15 });
Assert.AreEqual(array[2], new[] { 16, 17, 100, 19, 200, 21, 22, 23 });
Assert.AreEqual(array[3], new[] { 24, 25, 100, 27, 200, 29, 30, 31 });
}
}
}

Loading…
Cancel
Save