Browse Source

fix ToMultiDimArray

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
c7ee2308fb
5 changed files with 100 additions and 23 deletions
  1. +57
    -0
      src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
  2. +17
    -17
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  3. +5
    -1
      src/TensorFlowNET.Core/Numpy/NDArray.cs
  4. +5
    -5
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs
  5. +16
    -0
      test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

+ 57
- 0
src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs View File

@@ -30,5 +30,62 @@ namespace Tensorflow.NumPy
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
};

public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
{
var ret = Array.CreateInstance(typeof(T), nd.shape.as_int_list());

var addr = ret switch
{
T[] array => Addr(array),
T[,] array => Addr(array),
T[,,] array => Addr(array),
T[,,,] array => Addr(array),
T[,,,,] array => Addr(array),
T[,,,,,] array => Addr(array),
_ => throw new NotImplementedException("")
};

System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize);
return ret;
}

#region multiple array
static unsafe T* Addr<T>(T[] array) where T : unmanaged
{
fixed (T* a = &array[0])
return a;
}

static unsafe T* Addr<T>(T[,] array) where T : unmanaged
{
fixed (T* a = &array[0, 0])
return a;
}

static unsafe T* Addr<T>(T[,,] array) where T : unmanaged
{
fixed (T* a = &array[0, 0, 0])
return a;
}

static unsafe T* Addr<T>(T[,,,] array) where T : unmanaged
{
fixed (T* a = &array[0, 0, 0, 0])
return a;
}

static unsafe T* Addr<T>(T[,,,,] array) where T : unmanaged
{
fixed (T* a = &array[0, 0, 0, 0, 0])
return a;
}

static unsafe T* Addr<T>(T[,,,,,] array) where T : unmanaged
{
fixed (T* a = &array[0, 0, 0, 0, 0, 0])
return a;
}
#endregion
}
}

+ 17
- 17
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -8,28 +8,28 @@ namespace Tensorflow.NumPy
{
public partial class NDArray
{
public NDArray(bool value) : base(value) { NewEagerTensorHandle(); }
public NDArray(byte value) : base(value) { NewEagerTensorHandle(); }
public NDArray(short value) : base(value) { NewEagerTensorHandle(); }
public NDArray(int value) : base(value) { NewEagerTensorHandle(); }
public NDArray(long value) : base(value) { NewEagerTensorHandle(); }
public NDArray(float value) : base(value) { NewEagerTensorHandle(); }
public NDArray(double value) : base(value) { NewEagerTensorHandle(); }
public NDArray(bool value) : base(value) => NewEagerTensorHandle();
public NDArray(byte value) : base(value) => NewEagerTensorHandle();
public NDArray(short value) : base(value) => NewEagerTensorHandle();
public NDArray(int value) : base(value) => NewEagerTensorHandle();
public NDArray(long value) : base(value) => NewEagerTensorHandle();
public NDArray(float value) : base(value) => NewEagerTensorHandle();
public NDArray(double value) : base(value) => NewEagerTensorHandle();

public NDArray(Array value, Shape? shape = null)
: base(value, shape) { NewEagerTensorHandle(); }
public NDArray(Array value, Shape? shape = null) : base(value, shape)
=> NewEagerTensorHandle();

public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
: base(shape, dtype: dtype) { NewEagerTensorHandle(); }
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) : base(shape, dtype: dtype)
=> NewEagerTensorHandle();

public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
: base(bytes, shape, dtype) { NewEagerTensorHandle(); }
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype)
=> NewEagerTensorHandle();

public NDArray(long[] value, Shape? shape = null)
: base(value, shape) { NewEagerTensorHandle(); }
public NDArray(long[] value, Shape? shape = null) : base(value, shape)
=> NewEagerTensorHandle();

public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
: base(address, shape, dtype) { NewEagerTensorHandle(); }
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype)
=> NewEagerTensorHandle();

public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone)
{


+ 5
- 1
src/TensorFlowNET.Core/Numpy/NDArray.cs View File

@@ -19,6 +19,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.NumPy
@@ -35,7 +36,10 @@ namespace Tensorflow.NumPy
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype));
public NDArray ravel() => throw new NotImplementedException("");
public void shuffle(NDArray nd) => np.random.shuffle(nd);
public Array ToMuliDimArray<T>() => throw new NotImplementedException("");

public unsafe Array ToMultiDimArray<T>() where T : unmanaged
=> NDArrayConverter.ToMultiDimArray<T>(this);

public byte[] ToByteArray() => BufferToArray();
public override string ToString() => NDArrayRender.ToString(this);



+ 5
- 5
src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -273,19 +273,19 @@ namespace Tensorflow.Keras.Saving
switch (data.dtype)
{
case TF_DataType.TF_FLOAT:
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
break;
case TF_DataType.TF_DOUBLE:
Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMuliDimArray<double>());
Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMultiDimArray<double>());
break;
case TF_DataType.TF_INT32:
Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMuliDimArray<int>());
Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMultiDimArray<int>());
break;
case TF_DataType.TF_INT64:
Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMuliDimArray<long>());
Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMultiDimArray<long>());
break;
default:
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
break;
}
}


+ 16
- 0
test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs View File

@@ -50,6 +50,22 @@ namespace TensorFlowNET.UnitTest.NumPy
AssetSequenceEqual(new[] { 1, 2, 3, 4, 5, 6 }, x.ToArray<int>());
}

[TestMethod]
public void to_multi_dim_array()
{
var x1 = np.arange(12);
var y1 = x1.ToMultiDimArray<int>();
AssetSequenceEqual((int[])y1, x1.ToArray<int>());

var x2 = np.arange(12).reshape((2, 6));
var y2 = (int[,])x2.ToMultiDimArray<int>();
Assert.AreEqual(x2[0, 5], y2[0, 5]);

var x3 = np.arange(12).reshape((2, 2, 3));
var y3 = (int[,,])x3.ToMultiDimArray<int>();
Assert.AreEqual(x3[0, 1, 2], y3[0, 1, 2]);
}

[TestMethod]
public void eye()
{


Loading…
Cancel
Save