diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 9f11e5b8..c79b6f3a 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -513,10 +513,13 @@ namespace Tensorflow if (data is NDArray nd) return nd.shape; - if (data is Tensor tensor) + else if (data is Tensor tensor) return tensor.shape; - if (!data.GetType().IsArray) + else if (data is Axis axis) + return axis.IsScalar ? Shape.Scalar : new Shape(axis.axis); + + else if (!data.GetType().IsArray) return Shape.Scalar; switch (data) diff --git a/src/TensorFlowNET.Core/Data/MnistDataSet.cs b/src/TensorFlowNET.Core/Data/MnistDataSet.cs index 8ccb0487..7e5d0cc2 100644 --- a/src/TensorFlowNET.Core/Data/MnistDataSet.cs +++ b/src/TensorFlowNET.Core/Data/MnistDataSet.cs @@ -17,7 +17,7 @@ namespace Tensorflow NumOfExamples = (int)images.dims[0]; - images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); + // images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); images = images.astype(dataType); // for debug np.multiply performance var sw = new Stopwatch(); diff --git a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs index 514dbfb7..2e033f3e 100644 --- a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs +++ b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs @@ -123,9 +123,7 @@ namespace Tensorflow bytestream.Read(buf, 0, buf.Length); - var data = np.frombuffer(buf, new Shape(buf.Length), np.@byte); - data = data.reshape((num_images, rows, cols, 1)); - + var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte); return data; } } diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 7f32fef3..3c43686d 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -24,6 +24,7 @@ namespace Tensorflow public record Axis(params int[] axis) { public int size => axis == null ? -1 : axis.Length; + public bool IsScalar { get; init; } public int this[int index] => axis[index]; @@ -34,7 +35,7 @@ namespace Tensorflow => axis.axis[0]; public static implicit operator Axis(int axis) - => new Axis(axis); + => new Axis(axis) { IsScalar = true }; public static implicit operator Axis((int, int) axis) => new Axis(axis.Item1, axis.Item2); diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 2966e17c..58cbb4ce 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -15,7 +15,7 @@ namespace Tensorflow.NumPy Start = x, Stop = x + 1, IsIndex = true - })); + }).ToArray()); set => SetData(indices.Select(x => { @@ -55,21 +55,58 @@ namespace Tensorflow.NumPy } } - NDArray GetData(IEnumerable slices) + + unsafe NDArray GetData(Slice[] slices) { if (shape.IsScalar) return GetScalar(); + if (SliceHelper.AreAllIndex(slices, out var indices1)) + { + var newshape = ShapeHelper.GetShape(shape, slices); + if (newshape.IsScalar) + { + var offset = ShapeHelper.GetOffset(shape, indices1); + return GetScalar((ulong)offset); + } + else + { + return GetArrayData(newshape, indices1); + } + } + else if (slices.Count() == 1) + { + var slice = slices[0]; + if (slice.Step == 1) + { + var newshape = ShapeHelper.GetShape(shape, slice); + var array = new NDArray(newshape, dtype: dtype); + + var new_dims = new int[shape.ndim]; + new_dims[0] = slice.Start ?? 0; + //for (int i = 1; i < shape.ndim; i++) + //new_dims[i] = (int)shape.dims[i]; + + var offset = ShapeHelper.GetOffset(shape, new_dims); + var src = (byte*)data + (ulong)offset * dtypesize; + var dst = (byte*)array.data; + var len = (ulong)newshape.size * dtypesize; + + System.Buffer.MemoryCopy(src, dst, len, len); + + return array; + } + } + + // default, performance is bad var tensor = base[slices.ToArray()]; if (tensor.Handle == null) { if (tf.executing_eagerly()) tensor = tf.defaultSession.eval(tensor); - else - return new NDArray(tensor); } - - return new NDArray(tensor); + + return new NDArray(tensor, tf.executing_eagerly()); } unsafe T GetAtIndex(params int[] indices) where T : unmanaged @@ -78,17 +115,26 @@ namespace Tensorflow.NumPy return *((T*)data + offset); } - NDArray GetScalar() + unsafe NDArray GetScalar(ulong offset = 0) { var array = new NDArray(Shape.Scalar, dtype: dtype); - unsafe - { - var src = (byte*)data + dtypesize; - System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize); - } + var src = (byte*)data + offset * dtypesize; + System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), dtypesize, dtypesize); return array; } + unsafe NDArray GetArrayData(Shape newshape, int[] indices) + { + var offset = ShapeHelper.GetOffset(shape, indices); + var len = (ulong)newshape.size * dtypesize; + var array = new NDArray(newshape, dtype: dtype); + + var src = (byte*)data + (ulong)offset * dtypesize; + System.Buffer.MemoryCopy(src, array.data.ToPointer(), len, len); + + return array; + } + NDArray GetData(int[] indices, int axis = 0) { if (shape.IsScalar) diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs index 1493b05d..6e2a7926 100644 --- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow.NumPy { - internal class ShapeHelper + public class ShapeHelper { public static long GetSize(Shape shape) { @@ -41,6 +41,34 @@ namespace Tensorflow.NumPy return strides; } + public static Shape GetShape(Shape shape1, params Slice[] slices) + { + var new_dims = shape1.dims.ToArray(); + slices = SliceHelper.AlignWithShape(shape1, slices); + + for (int i = 0; i < shape1.dims.Length; i++) + { + Slice slice = slices[i]; + if (slice.Equals(Slice.All)) + new_dims[i] = shape1.dims[i]; + else if (slice.IsIndex) + new_dims[i] = 1; + else // range + new_dims[i] = (slice.Stop ?? shape1.dims[i]) - (slice.Start ?? 0); + } + + // strip first dim if is index + var return_dims = new List(); + for (int i = 0; i< new_dims.Length; i++) + { + if (slices[i].IsIndex) + continue; + return_dims.add(new_dims[i]); + } + + return new Shape(return_dims.ToArray()); + } + public static bool Equals(Shape shape, object target) { switch (target) diff --git a/src/TensorFlowNET.Core/NumPy/SliceHelper.cs b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs new file mode 100644 index 00000000..1090ce27 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class SliceHelper + { + public static Slice[] AlignWithShape(Shape shape, Slice[] slices) + { + // align slices + var ndim = shape.ndim; + var new_slices = new List(); + var slice_index = 0; + + for (int i = 0; i < ndim; i++) + { + if (slice_index > slices.Length - 1) + { + new_slices.Add(Slice.All); + continue; + } + + if (slices[slice_index] == Slice.All) + { + new_slices.Add(Slice.All); + for (int j = 0; j < ndim - slices.Length; j++) + { + new_slices.Add(Slice.All); + i++; + } + } + else + { + new_slices.Add(slices[slice_index]); + } + slice_index++; + } + + return new_slices.ToArray(); + } + + public static bool AreAllIndex(Slice[] slices, out int[] indices) + { + indices = new int[slices.Length]; + for (int i = 0; i< slices.Length; i++) + { + indices[i] = slices[i].Start ?? 0; + if (!slices[i].IsIndex) + return false; + } + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index 7e19029d..87658a32 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -28,7 +28,7 @@ namespace Tensorflow.NumPy public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype) { NewEagerTensorHandle(); } - public NDArray(Tensor tensor) : base(tensor.Handle) + public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle) { if (_handle is null) { @@ -53,9 +53,12 @@ namespace Tensorflow.NumPy void NewEagerTensorHandle() { - _id = ops.uid(); - _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); - tf.Status.Check(true); + if(_handle is not null) + { + _id = ops.uid(); + _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); + tf.Status.Check(true); + } } } } diff --git a/src/TensorFlowNET.Core/Numpy/Slice.cs b/src/TensorFlowNET.Core/Numpy/Slice.cs index 2bb73fe8..676ec5e9 100644 --- a/src/TensorFlowNET.Core/Numpy/Slice.cs +++ b/src/TensorFlowNET.Core/Numpy/Slice.cs @@ -115,11 +115,12 @@ namespace Tensorflow /// Start index of the slice, null means from the start of the array /// Stop index (first index after end of slice), null means to the end of the array /// Optional step to select every n-th element, defaults to 1 - public Slice(int? start = null, int? stop = null, int step = 1) + public Slice(int? start = null, int? stop = null, int step = 1, bool isIndex = false) { Start = start; Stop = stop; Step = step; + IsIndex = isIndex; } public Slice(string slice_notation) diff --git a/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs b/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs index ea7ec4e2..5730f0cd 100644 --- a/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs +++ b/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs @@ -8,7 +8,7 @@ namespace Tensorflow public sealed class SafeStringTensorHandle : SafeTensorHandle { Shape _shape; - SafeTensorHandle _handle; + IntPtr _handle; const int TF_TSRING_SIZE = 24; protected SafeStringTensorHandle() @@ -18,7 +18,7 @@ namespace Tensorflow public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) : base(handle.DangerousGetHandle()) { - _handle = handle; + _handle = c_api.TF_TensorData(handle); _shape = shape; } @@ -28,15 +28,10 @@ namespace Tensorflow print($"Delete StringTensorHandle 0x{handle.ToString("x16")}"); #endif - long size = 1; - foreach (var s in _shape.dims) - size *= s; - var tstr = c_api.TF_TensorData(_handle); - - for (int i = 0; i < size; i++) + for (int i = 0; i < _shape.size; i++) { - c_api.TF_StringDealloc(tstr); - tstr += TF_TSRING_SIZE; + c_api.TF_StringDealloc(_handle); + _handle += TF_TSRING_SIZE; } SetHandle(IntPtr.Zero); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index 79406aa2..50976550 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -23,7 +23,7 @@ namespace Tensorflow public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape) { var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, - shape.ndim == 0 ? null : shape.dims, + shape.dims, shape.ndim, (ulong)shape.size * TF_TSRING_SIZE); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 99b1f2d1..f694de82 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -472,6 +472,9 @@ would not be rank 1.", tensor.op.get_attr("axis"))); public static string to_numpy_string(Tensor tensor) { + if (tensor.buffer == IntPtr.Zero) + return "Empty"; + var dtype = tensor.dtype; var shape = tensor.shape; diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 159d2bf9..c4e79d13 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -161,7 +161,7 @@ namespace Tensorflow IEnumerable tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), - Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), + Axis ts => constant_op.constant(ts, dtype: dtype, name: name), Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), string str => constant_op.constant(str, dtype: tf.@string, name: name), string[] str => constant_op.constant(str, dtype: tf.@string, name: name), diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs index d1408445..b46b4872 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs @@ -63,5 +63,43 @@ namespace TensorFlowNET.UnitTest.NumPy input_shape_val[(int)input_shape.size - 1] = 1; input_shape.Dispose(); } + + [TestMethod] + public void shape_helper_get_shape_3dim() + { + var x = np.arange(24).reshape((4, 3, 2)); + var shape1 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true)); + Assert.AreEqual(shape1, (3, 2)); + + var shape2 = ShapeHelper.GetShape(x.shape, new Slice(1)); + Assert.AreEqual(shape2, (3, 3, 2)); + + var shape3 = ShapeHelper.GetShape(x.shape, new Slice(2), Slice.All); + Assert.AreEqual(shape3, (2, 3, 2)); + + var shape4 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(2)); + Assert.AreEqual(shape4, (1, 2)); + + var shape5 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(1)); + Assert.AreEqual(shape5, (2, 2)); + + var shape6 = ShapeHelper.GetShape(x.shape, new Slice(1), new Slice(1, isIndex: true), new Slice(1)); + Assert.AreEqual(shape6, (3, 1)); + } + + [TestMethod] + public void shape_helper_get_shape_4dim() + { + var x = np.arange(120).reshape((4, 3, 2, 5)); + var slices = new[] { new Slice(1, isIndex: true), new Slice(1), new Slice(0, isIndex: true), new Slice(1) }; + var shape1 = ShapeHelper.GetShape(x.shape, slices); + Assert.AreEqual(shape1, (2, 4)); + + var shape2 = ShapeHelper.GetShape(x.shape, Slice.All); + Assert.AreEqual(shape2, (4, 3, 2, 5)); + + var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true)); + Assert.AreEqual(shape3, (4, 3, 2)); + } } }