Browse Source

fix memory crash when index < 0.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
3c020dc492
4 changed files with 38 additions and 9 deletions
  1. +24
    -8
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  2. +3
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  4. +10
    -0
      test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

+ 24
- 8
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -10,24 +10,32 @@ namespace Tensorflow.NumPy
{
public NDArray this[params int[] index]
{
get => _tensor[index.Select(x => new Slice
get => GetData(index.Select(x => new Slice
{
Start = x,
Stop = x + 1,
IsIndex = true
}).ToArray()];
}));

set => SetData(index.Select(x => new Slice
set => SetData(index.Select(x =>
{
Start = x,
Stop = x + 1,
IsIndex = true
if(x < 0)
x = (int)dims[0] + x;
var slice = new Slice
{
Start = x,
Stop = x + 1,
IsIndex = true
};

return slice;
}), value);
}

public NDArray this[params Slice[] slices]
{
get => _tensor[slices];
get => GetData(slices);
set => SetData(slices, value);
}

@@ -44,6 +52,11 @@ namespace Tensorflow.NumPy
}
}

NDArray GetData(IEnumerable<Slice> slices)
{
return _tensor[slices.ToArray()];
}

void SetData(IEnumerable<Slice> slices, NDArray array)
=> SetData(slices, array, -1, slices.Select(x => 0).ToArray());

@@ -61,7 +74,10 @@ namespace Tensorflow.NumPy
{

if (slice.Step != 1)
throw new NotImplementedException("");
throw new NotImplementedException("slice.step should == 1");

if (slice.Start < 0)
throw new NotImplementedException("slice.start should > -1");

indices[indices.Length - 1] = slice.Start ?? 0;
var offset = (ulong)ShapeHelper.GetOffset(shape, indices);


+ 3
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -81,6 +81,9 @@ namespace Tensorflow.NumPy
for (int i = 0; i < indices.Length; i++)
offset += strides[i] * indices[i];

if (offset < 0)
throw new NotImplementedException("");

return offset;
}
}


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow

var tstr = c_api.TF_TensorData(handle);
#if TRACK_TENSOR_LIFE
print($"New TString 0x{handle.ToString("x16")} {AllocationType} Data: 0x{tstr.ToString("x16")}");
print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}");
#endif
for (int i = 0; i < buffer.Length; i++)
{


+ 10
- 0
test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs View File

@@ -5,6 +5,7 @@ using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NumPy
{
@@ -53,5 +54,14 @@ namespace TensorFlowNET.UnitTest.NumPy
Assert.AreEqual(y.shape, (1, 2));
Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2)));
}

[TestMethod]
public void slice_out_bound()
{
var input_shape = tf.constant(new int[] { 1, 1 });
var input_shape_val = input_shape.numpy();
input_shape_val[(int)input_shape.size - 1] = 1;
input_shape.Dispose();
}
}
}

Loading…
Cancel
Save