Browse Source

optimize slice.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
d88fec40d2
14 changed files with 209 additions and 37 deletions
  1. +5
    -2
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Data/MnistDataSet.cs
  3. +1
    -3
      src/TensorFlowNET.Core/Data/MnistModelLoader.cs
  4. +2
    -1
      src/TensorFlowNET.Core/NumPy/Axis.cs
  5. +58
    -12
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  6. +29
    -1
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  7. +56
    -0
      src/TensorFlowNET.Core/NumPy/SliceHelper.cs
  8. +7
    -4
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  9. +2
    -1
      src/TensorFlowNET.Core/Numpy/Slice.cs
  10. +5
    -10
      src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  12. +3
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  13. +1
    -1
      src/TensorFlowNET.Core/ops.cs
  14. +38
    -0
      test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

+ 5
- 2
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -513,10 +513,13 @@ namespace Tensorflow
if (data is NDArray nd)
return nd.shape;

if (data is Tensor tensor)
else if (data is Tensor tensor)
return tensor.shape;

if (!data.GetType().IsArray)
else if (data is Axis axis)
return axis.IsScalar ? Shape.Scalar : new Shape(axis.axis);

else if (!data.GetType().IsArray)
return Shape.Scalar;

switch (data)


+ 1
- 1
src/TensorFlowNET.Core/Data/MnistDataSet.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow

NumOfExamples = (int)images.dims[0];

images = images.reshape((images.dims[0], images.dims[1] * images.dims[2]));
// images = images.reshape((images.dims[0], images.dims[1] * images.dims[2]));
images = images.astype(dataType);
// for debug np.multiply performance
var sw = new Stopwatch();


+ 1
- 3
src/TensorFlowNET.Core/Data/MnistModelLoader.cs View File

@@ -123,9 +123,7 @@ namespace Tensorflow

bytestream.Read(buf, 0, buf.Length);

var data = np.frombuffer(buf, new Shape(buf.Length), np.@byte);
data = data.reshape((num_images, rows, cols, 1));

var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte);
return data;
}
}


+ 2
- 1
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -24,6 +24,7 @@ namespace Tensorflow
public record Axis(params int[] axis)
{
public int size => axis == null ? -1 : axis.Length;
public bool IsScalar { get; init; }

public int this[int index] => axis[index];

@@ -34,7 +35,7 @@ namespace Tensorflow
=> axis.axis[0];

public static implicit operator Axis(int axis)
=> new Axis(axis);
=> new Axis(axis) { IsScalar = true };

public static implicit operator Axis((int, int) axis)
=> new Axis(axis.Item1, axis.Item2);


+ 58
- 12
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow.NumPy
Start = x,
Stop = x + 1,
IsIndex = true
}));
}).ToArray());

set => SetData(indices.Select(x =>
{
@@ -55,21 +55,58 @@ namespace Tensorflow.NumPy
}
}

NDArray GetData(IEnumerable<Slice> slices)
unsafe NDArray GetData(Slice[] slices)
{
if (shape.IsScalar)
return GetScalar();

if (SliceHelper.AreAllIndex(slices, out var indices1))
{
var newshape = ShapeHelper.GetShape(shape, slices);
if (newshape.IsScalar)
{
var offset = ShapeHelper.GetOffset(shape, indices1);
return GetScalar((ulong)offset);
}
else
{
return GetArrayData(newshape, indices1);
}
}
else if (slices.Count() == 1)
{
var slice = slices[0];
if (slice.Step == 1)
{
var newshape = ShapeHelper.GetShape(shape, slice);
var array = new NDArray(newshape, dtype: dtype);

var new_dims = new int[shape.ndim];
new_dims[0] = slice.Start ?? 0;
//for (int i = 1; i < shape.ndim; i++)
//new_dims[i] = (int)shape.dims[i];

var offset = ShapeHelper.GetOffset(shape, new_dims);
var src = (byte*)data + (ulong)offset * dtypesize;
var dst = (byte*)array.data;
var len = (ulong)newshape.size * dtypesize;

System.Buffer.MemoryCopy(src, dst, len, len);

return array;
}
}

// default, performance is bad
var tensor = base[slices.ToArray()];
if (tensor.Handle == null)
{
if (tf.executing_eagerly())
tensor = tf.defaultSession.eval(tensor);
else
return new NDArray(tensor);
}
return new NDArray(tensor);

return new NDArray(tensor, tf.executing_eagerly());
}

unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged
@@ -78,17 +115,26 @@ namespace Tensorflow.NumPy
return *((T*)data + offset);
}

NDArray GetScalar()
unsafe NDArray GetScalar(ulong offset = 0)
{
var array = new NDArray(Shape.Scalar, dtype: dtype);
unsafe
{
var src = (byte*)data + dtypesize;
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize);
}
var src = (byte*)data + offset * dtypesize;
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), dtypesize, dtypesize);
return array;
}

unsafe NDArray GetArrayData(Shape newshape, int[] indices)
{
var offset = ShapeHelper.GetOffset(shape, indices);
var len = (ulong)newshape.size * dtypesize;
var array = new NDArray(newshape, dtype: dtype);

var src = (byte*)data + (ulong)offset * dtypesize;
System.Buffer.MemoryCopy(src, array.data.ToPointer(), len, len);

return array;
}
NDArray GetData(int[] indices, int axis = 0)
{
if (shape.IsScalar)


+ 29
- 1
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -5,7 +5,7 @@ using System.Text;

namespace Tensorflow.NumPy
{
internal class ShapeHelper
public class ShapeHelper
{
public static long GetSize(Shape shape)
{
@@ -41,6 +41,34 @@ namespace Tensorflow.NumPy
return strides;
}

public static Shape GetShape(Shape shape1, params Slice[] slices)
{
var new_dims = shape1.dims.ToArray();
slices = SliceHelper.AlignWithShape(shape1, slices);

for (int i = 0; i < shape1.dims.Length; i++)
{
Slice slice = slices[i];
if (slice.Equals(Slice.All))
new_dims[i] = shape1.dims[i];
else if (slice.IsIndex)
new_dims[i] = 1;
else // range
new_dims[i] = (slice.Stop ?? shape1.dims[i]) - (slice.Start ?? 0);
}

// strip first dim if is index
var return_dims = new List<long>();
for (int i = 0; i< new_dims.Length; i++)
{
if (slices[i].IsIndex)
continue;
return_dims.add(new_dims[i]);
}

return new Shape(return_dims.ToArray());
}

public static bool Equals(Shape shape, object target)
{
switch (target)


+ 56
- 0
src/TensorFlowNET.Core/NumPy/SliceHelper.cs View File

@@ -0,0 +1,56 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.NumPy
{
public class SliceHelper
{
public static Slice[] AlignWithShape(Shape shape, Slice[] slices)
{
// align slices
var ndim = shape.ndim;
var new_slices = new List<Slice>();
var slice_index = 0;

for (int i = 0; i < ndim; i++)
{
if (slice_index > slices.Length - 1)
{
new_slices.Add(Slice.All);
continue;
}

if (slices[slice_index] == Slice.All)
{
new_slices.Add(Slice.All);
for (int j = 0; j < ndim - slices.Length; j++)
{
new_slices.Add(Slice.All);
i++;
}
}
else
{
new_slices.Add(slices[slice_index]);
}
slice_index++;
}

return new_slices.ToArray();
}

public static bool AreAllIndex(Slice[] slices, out int[] indices)
{
indices = new int[slices.Length];
for (int i = 0; i< slices.Length; i++)
{
indices[i] = slices[i].Start ?? 0;
if (!slices[i].IsIndex)
return false;
}
return true;
}
}
}

+ 7
- 4
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.NumPy
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
: base(address, shape, dtype) { NewEagerTensorHandle(); }

public NDArray(Tensor tensor) : base(tensor.Handle)
public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle)
{
if (_handle is null)
{
@@ -53,9 +53,12 @@ namespace Tensorflow.NumPy

void NewEagerTensorHandle()
{
_id = ops.uid();
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
tf.Status.Check(true);
if(_handle is not null)
{
_id = ops.uid();
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
tf.Status.Check(true);
}
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Numpy/Slice.cs View File

@@ -115,11 +115,12 @@ namespace Tensorflow
/// <param name="start">Start index of the slice, null means from the start of the array</param>
/// <param name="stop">Stop index (first index after end of slice), null means to the end of the array</param>
/// <param name="step">Optional step to select every n-th element, defaults to 1</param>
public Slice(int? start = null, int? stop = null, int step = 1)
public Slice(int? start = null, int? stop = null, int step = 1, bool isIndex = false)
{
Start = start;
Stop = stop;
Step = step;
IsIndex = isIndex;
}

public Slice(string slice_notation)


+ 5
- 10
src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
public sealed class SafeStringTensorHandle : SafeTensorHandle
{
Shape _shape;
SafeTensorHandle _handle;
IntPtr _handle;
const int TF_TSRING_SIZE = 24;

protected SafeStringTensorHandle()
@@ -18,7 +18,7 @@ namespace Tensorflow
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape)
: base(handle.DangerousGetHandle())
{
_handle = handle;
_handle = c_api.TF_TensorData(handle);
_shape = shape;
}

@@ -28,15 +28,10 @@ namespace Tensorflow
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}");
#endif

long size = 1;
foreach (var s in _shape.dims)
size *= s;
var tstr = c_api.TF_TensorData(_handle);

for (int i = 0; i < size; i++)
for (int i = 0; i < _shape.size; i++)
{
c_api.TF_StringDealloc(tstr);
tstr += TF_TSRING_SIZE;
c_api.TF_StringDealloc(_handle);
_handle += TF_TSRING_SIZE;
}

SetHandle(IntPtr.Zero);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow
public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape)
{
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
shape.ndim == 0 ? null : shape.dims,
shape.dims,
shape.ndim,
(ulong)shape.size * TF_TSRING_SIZE);



+ 3
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -472,6 +472,9 @@ would not be rank 1.", tensor.op.get_attr("axis")));

public static string to_numpy_string(Tensor tensor)
{
if (tensor.buffer == IntPtr.Zero)
return "Empty";

var dtype = tensor.dtype;
var shape = tensor.shape;



+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -161,7 +161,7 @@ namespace Tensorflow
IEnumerable<Tensor> tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name),
Axis ts => constant_op.constant(ts, dtype: dtype, name: name),
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
string str => constant_op.constant(str, dtype: tf.@string, name: name),
string[] str => constant_op.constant(str, dtype: tf.@string, name: name),


+ 38
- 0
test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs View File

@@ -63,5 +63,43 @@ namespace TensorFlowNET.UnitTest.NumPy
input_shape_val[(int)input_shape.size - 1] = 1;
input_shape.Dispose();
}

[TestMethod]
public void shape_helper_get_shape_3dim()
{
var x = np.arange(24).reshape((4, 3, 2));
var shape1 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true));
Assert.AreEqual(shape1, (3, 2));

var shape2 = ShapeHelper.GetShape(x.shape, new Slice(1));
Assert.AreEqual(shape2, (3, 3, 2));

var shape3 = ShapeHelper.GetShape(x.shape, new Slice(2), Slice.All);
Assert.AreEqual(shape3, (2, 3, 2));

var shape4 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(2));
Assert.AreEqual(shape4, (1, 2));

var shape5 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(1));
Assert.AreEqual(shape5, (2, 2));

var shape6 = ShapeHelper.GetShape(x.shape, new Slice(1), new Slice(1, isIndex: true), new Slice(1));
Assert.AreEqual(shape6, (3, 1));
}

[TestMethod]
public void shape_helper_get_shape_4dim()
{
var x = np.arange(120).reshape((4, 3, 2, 5));
var slices = new[] { new Slice(1, isIndex: true), new Slice(1), new Slice(0, isIndex: true), new Slice(1) };
var shape1 = ShapeHelper.GetShape(x.shape, slices);
Assert.AreEqual(shape1, (2, 4));

var shape2 = ShapeHelper.GetShape(x.shape, Slice.All);
Assert.AreEqual(shape2, (4, 3, 2, 5));

var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true));
Assert.AreEqual(shape3, (4, 3, 2));
}
}
}

Loading…
Cancel
Save