Browse Source

slice_step_setter_diff_shape

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
30f8f67351
4 changed files with 86 additions and 15 deletions
  1. +60
    -14
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  2. +2
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
  3. +11
    -0
      src/TensorFlowNET.Core/NumPy/SliceHelper.cs
  4. +13
    -1
      test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

+ 60
- 14
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -57,7 +57,7 @@ namespace Tensorflow.NumPy
} }
} }


[AutoNumPy]
unsafe NDArray GetData(Slice[] slices) unsafe NDArray GetData(Slice[] slices)
{ {
if (shape.IsScalar) if (shape.IsScalar)
@@ -170,9 +170,9 @@ namespace Tensorflow.NumPy
} }


void SetData(IEnumerable<Slice> slices, NDArray array) void SetData(IEnumerable<Slice> slices, NDArray array)
=> SetData(array, data, slices.ToArray(), new int[shape.ndim].ToArray(), -1);
=> SetData(array, slices.ToArray(), new int[shape.ndim].ToArray(), -1);


unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int currentNDim)
unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim)
{ {
if (dtype != src.dtype) if (dtype != src.dtype)
src = src.astype(dtype); src = src.astype(dtype);
@@ -181,20 +181,23 @@ namespace Tensorflow.NumPy
if (!slices.Any()) if (!slices.Any())
return; return;


if (shape.Equals(src.shape))
{
System.Buffer.MemoryCopy(src.data.ToPointer(), data.ToPointer(), src.bytesize, src.bytesize);
return;
}

// first iteration // first iteration
if(currentNDim == -1) if(currentNDim == -1)
{ {
slices = SliceHelper.AlignWithShape(shape, slices); slices = SliceHelper.AlignWithShape(shape, slices);
if (!shape.Equals(src.shape))
{
var newShape = ShapeHelper.AlignWithShape(shape, src.shape);
src = src.reshape(newShape);
}
} }


// last dimension // last dimension
if (currentNDim == ndim - 1) if (currentNDim == ndim - 1)
{ {
var offset = (int)ShapeHelper.GetOffset(shape, indices);
var dst = data + offset * (int)dtypesize;
System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize); System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
return; return;
} }
@@ -206,13 +209,56 @@ namespace Tensorflow.NumPy
var stop = slice.Stop ?? (int)dims[currentNDim]; var stop = slice.Stop ?? (int)dims[currentNDim];
var step = slice.Step; var step = slice.Step;


for (var i = start; i < stop; i += step)
if(step != 1)
{ {
indices[currentNDim] = i;
var offset = (int)ShapeHelper.GetOffset(shape, indices);
dst = data + offset * (int)dtypesize;
var srcIndex = (i - start) / step;
SetData(src[srcIndex], dst, slices, indices, currentNDim);
for (var i = start; i < stop; i += step)
{
if (i >= dims[currentNDim])
throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}");

indices[currentNDim] = i;
if (currentNDim < ndim - src.ndim)
{
SetData(src, slices, indices, currentNDim);
}
else
{
var srcIndex = (i - start) / step;
SetData(src[srcIndex], slices, indices, currentNDim);
}
}
}
else
{
for (var i = start; i < stop; i++)
{
if (i >= dims[currentNDim])
throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}");

indices[currentNDim] = i;
if (currentNDim < ndim - src.ndim)
{
SetData(src, slices, indices, currentNDim);
}
// last dimension
else if(currentNDim == ndim - 1)
{
SetData(src, slices, indices, currentNDim);
break;
}
else if(SliceHelper.IsContinuousBlock(slices, currentNDim))
{
var offset = (int)ShapeHelper.GetOffset(shape, indices);
var dst = data + offset * (int)dtypesize;
System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
return;
}
else
{
var srcIndex = i - start;
SetData(src[srcIndex], slices, indices, currentNDim);
}
}
} }


// reset indices // reset indices


+ 2
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs View File

@@ -16,6 +16,8 @@ namespace Tensorflow.NumPy
public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs)); public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs));
[AutoNumPy] [AutoNumPy]
public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs)); public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs));
[AutoNumPy]
public static NDArray operator %(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mod", lhs, rhs));
[AutoNumPy] [AutoNumPy]
public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs)); public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs));
[AutoNumPy] [AutoNumPy]


+ 11
- 0
src/TensorFlowNET.Core/NumPy/SliceHelper.cs View File

@@ -55,5 +55,16 @@ namespace Tensorflow.NumPy
} }
return true; return true;
} }

public static bool IsContinuousBlock(Slice[] slices, int ndim)
{
for (int i = ndim + 1; i < slices.Length; i++)
{
if (slices[i].Equals(Slice.All))
continue;
return false;
}
return true;
}
} }
} }

+ 13
- 1
test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs View File

@@ -118,7 +118,7 @@ namespace TensorFlowNET.UnitTest.NumPy
} }


[TestMethod] [TestMethod]
public void slice_step()
public void slice_step_setter()
{ {
var array = np.arange(32).reshape((4, 8)); var array = np.arange(32).reshape((4, 8));
var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1; var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1;
@@ -131,5 +131,17 @@ namespace TensorFlowNET.UnitTest.NumPy
Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 }); Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 });
Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 }); Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 });
} }

[TestMethod]
public void slice_step_setter_diff_shape()
{
var array = np.arange(32).reshape((4, 8));
var s1 = np.array(new[] { 100, 200 });
array[Slice.All, new Slice(2, 5, 2)] = s1;
Assert.AreEqual(array[0], new[] { 0, 1, 100, 3, 200, 5, 6, 7 });
Assert.AreEqual(array[1], new[] { 8, 9, 100, 11, 200, 13, 14, 15 });
Assert.AreEqual(array[2], new[] { 16, 17, 100, 19, 200, 21, 22, 23 });
Assert.AreEqual(array[3], new[] { 24, 25, 100, 27, 200, 29, 30, 31 });
}
} }
} }

Loading…
Cancel
Save