From c7ee2308fbd9747cfce904ec7e51a8d9454d849d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 9 Aug 2021 22:58:05 -0500 Subject: [PATCH] fix ToMultiDimArray --- .../NumPy/NDArrayConverter.cs | 57 +++++++++++++++++++ .../Numpy/NDArray.Creation.cs | 34 +++++------ src/TensorFlowNET.Core/Numpy/NDArray.cs | 6 +- src/TensorFlowNET.Keras/Saving/hdf5_format.cs | 10 ++-- .../Numpy/Array.Creation.Test.cs | 16 ++++++ 5 files changed, 100 insertions(+), 23 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs index 6e1b8da1..921c213c 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs @@ -30,5 +30,62 @@ namespace Tensorflow.NumPy TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), _ => throw new NotImplementedException("") }; + + public static unsafe Array ToMultiDimArray(NDArray nd) where T : unmanaged + { + var ret = Array.CreateInstance(typeof(T), nd.shape.as_int_list()); + + var addr = ret switch + { + T[] array => Addr(array), + T[,] array => Addr(array), + T[,,] array => Addr(array), + T[,,,] array => Addr(array), + T[,,,,] array => Addr(array), + T[,,,,,] array => Addr(array), + _ => throw new NotImplementedException("") + }; + + System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize); + return ret; + } + + #region multiple array + static unsafe T* Addr(T[] array) where T : unmanaged + { + fixed (T* a = &array[0]) + return a; + } + + static unsafe T* Addr(T[,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0]) + return a; + } + + static unsafe T* Addr(T[,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0]) + return a; + } + + static unsafe T* Addr(T[,,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0, 0]) + return a; + } + + static unsafe T* Addr(T[,,,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0, 0, 0]) + return a; + } + + static unsafe T* Addr(T[,,,,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0, 0, 0, 0]) + return a; + } + #endregion } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index daec3b5c..ba61e3a7 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -8,28 +8,28 @@ namespace Tensorflow.NumPy { public partial class NDArray { - public NDArray(bool value) : base(value) { NewEagerTensorHandle(); } - public NDArray(byte value) : base(value) { NewEagerTensorHandle(); } - public NDArray(short value) : base(value) { NewEagerTensorHandle(); } - public NDArray(int value) : base(value) { NewEagerTensorHandle(); } - public NDArray(long value) : base(value) { NewEagerTensorHandle(); } - public NDArray(float value) : base(value) { NewEagerTensorHandle(); } - public NDArray(double value) : base(value) { NewEagerTensorHandle(); } + public NDArray(bool value) : base(value) => NewEagerTensorHandle(); + public NDArray(byte value) : base(value) => NewEagerTensorHandle(); + public NDArray(short value) : base(value) => NewEagerTensorHandle(); + public NDArray(int value) : base(value) => NewEagerTensorHandle(); + public NDArray(long value) : base(value) => NewEagerTensorHandle(); + public NDArray(float value) : base(value) => NewEagerTensorHandle(); + public NDArray(double value) : base(value) => NewEagerTensorHandle(); - public NDArray(Array value, Shape? shape = null) - : base(value, shape) { NewEagerTensorHandle(); } + public NDArray(Array value, Shape? shape = null) : base(value, shape) + => NewEagerTensorHandle(); - public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) - : base(shape, dtype: dtype) { NewEagerTensorHandle(); } + public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) : base(shape, dtype: dtype) + => NewEagerTensorHandle(); - public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) - : base(bytes, shape, dtype) { NewEagerTensorHandle(); } + public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype) + => NewEagerTensorHandle(); - public NDArray(long[] value, Shape? shape = null) - : base(value, shape) { NewEagerTensorHandle(); } + public NDArray(long[] value, Shape? shape = null) : base(value, shape) + => NewEagerTensorHandle(); - public NDArray(IntPtr address, Shape shape, TF_DataType dtype) - : base(address, shape, dtype) { NewEagerTensorHandle(); } + public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype) + => NewEagerTensorHandle(); public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone) { diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index dfc67e0d..75a5803d 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -19,6 +19,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.NumPy @@ -35,7 +36,10 @@ namespace Tensorflow.NumPy public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype)); public NDArray ravel() => throw new NotImplementedException(""); public void shuffle(NDArray nd) => np.random.shuffle(nd); - public Array ToMuliDimArray() => throw new NotImplementedException(""); + + public unsafe Array ToMultiDimArray() where T : unmanaged + => NDArrayConverter.ToMultiDimArray(this); + public byte[] ToByteArray() => BufferToArray(); public override string ToString() => NDArrayRender.ToString(this); diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs index 9f4ff149..50463b71 100644 --- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -273,19 +273,19 @@ namespace Tensorflow.Keras.Saving switch (data.dtype) { case TF_DataType.TF_FLOAT: - Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMuliDimArray()); + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); break; case TF_DataType.TF_DOUBLE: - Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMuliDimArray()); + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); break; case TF_DataType.TF_INT32: - Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMuliDimArray()); + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); break; case TF_DataType.TF_INT64: - Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMuliDimArray()); + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); break; default: - Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMuliDimArray()); + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); break; } } diff --git a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs index 799d40c4..fc309c3c 100644 --- a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs +++ b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs @@ -50,6 +50,22 @@ namespace TensorFlowNET.UnitTest.NumPy AssetSequenceEqual(new[] { 1, 2, 3, 4, 5, 6 }, x.ToArray()); } + [TestMethod] + public void to_multi_dim_array() + { + var x1 = np.arange(12); + var y1 = x1.ToMultiDimArray(); + AssetSequenceEqual((int[])y1, x1.ToArray()); + + var x2 = np.arange(12).reshape((2, 6)); + var y2 = (int[,])x2.ToMultiDimArray(); + Assert.AreEqual(x2[0, 5], y2[0, 5]); + + var x3 = np.arange(12).reshape((2, 2, 3)); + var y3 = (int[,,])x3.ToMultiDimArray(); + Assert.AreEqual(x3[0, 1, 2], y3[0, 1, 2]); + } + [TestMethod] public void eye() {