Browse Source

Add reconstruction and `__setstate__` support to NDArray for loading pickled .npy files.

tags/v0.110.4-Transformer-Model
lingbai-kong 2 years ago
parent
commit
9d10daf30f
8 changed files with 178 additions and 53 deletions
  1. +40
    -15
      src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs
  2. +0
    -3
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs
  3. +12
    -12
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs
  4. +22
    -13
      src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs
  5. +94
    -5
      src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs
  6. +1
    -0
      src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
  7. +3
    -1
      src/TensorFlowNET.Core/Numpy/Numpy.cs
  8. +6
    -4
      src/TensorFlowNET.Keras/Datasets/Imdb.cs

+ 40
- 15
src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs View File

@@ -16,25 +16,50 @@ namespace Tensorflow.NumPy
{
public object construct(object[] args)
{
Console.WriteLine("DtypeConstructor");
Console.WriteLine(args.Length);
for (int i = 0; i < args.Length; i++)
{
Console.WriteLine(args[i]);
}
return new demo();
var typeCode = (string)args[0];
TF_DataType dtype;
if (typeCode == "b1")
dtype = np.@bool;
else if (typeCode == "i1")
dtype = np.@byte;
else if (typeCode == "i2")
dtype = np.int16;
else if (typeCode == "i4")
dtype = np.int32;
else if (typeCode == "i8")
dtype = np.int64;
else if (typeCode == "u1")
dtype = np.ubyte;
else if (typeCode == "u2")
dtype = np.uint16;
else if (typeCode == "u4")
dtype = np.uint32;
else if (typeCode == "u8")
dtype = np.uint64;
else if (typeCode == "f4")
dtype = np.float32;
else if (typeCode == "f8")
dtype = np.float64;
else if (typeCode.StartsWith("S"))
dtype = np.@string;
else if (typeCode.StartsWith("O"))
dtype = np.@object;
else
throw new NotSupportedException();
return new TF_DataType_Warpper(dtype);
}
}
class demo
public class TF_DataType_Warpper
{
public void __setstate__(object[] args)
TF_DataType dtype { get; set; }
public TF_DataType_Warpper(TF_DataType dtype)
{
Console.WriteLine("demo __setstate__");
Console.WriteLine(args.Length);
for (int i = 0; i < args.Length; i++)
{
Console.WriteLine(args[i]);
}
this.dtype = dtype;
}
public void __setstate__(object[] args) { }
public static implicit operator TF_DataType(TF_DataType_Warpper dtypeWarpper)
{
return dtypeWarpper.dtype;
}
}
}

+ 0
- 3
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs View File

@@ -99,9 +99,6 @@ namespace Tensorflow.NumPy

NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape)
{
//int data = reader.ReadByte();
//Console.WriteLine(data);
//Console.WriteLine(reader.ReadByte());
Stream stream = reader.BaseStream;
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor());
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor());


+ 12
- 12
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs View File

@@ -28,17 +28,17 @@ namespace Tensorflow.NumPy

//if (type == typeof(String))
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
NDArray res = ReadObjectMatrix(reader, matrix, shape);
Console.WriteLine("LoadMatrix");
Console.WriteLine(res.dims[0]);
Console.WriteLine((int)res[0][0]);
Console.WriteLine(res.dims[1]);
//if (type == typeof(Object))
//{
//}
//else
return ReadValueMatrix(reader, matrix, bytes, type, shape);
if (type == typeof(Object))
{
NDArray res = ReadObjectMatrix(reader, matrix, shape);
// res = res.reconstructedNDArray;
return res.reconstructedArray;
}
else
{
return ReadValueMatrix(reader, matrix, bytes, type, shape);
}
}

}
@@ -133,7 +133,7 @@ namespace Tensorflow.NumPy
return typeof(Double);
if (typeCode.StartsWith("S"))
return typeof(String);
if (typeCode == "O")
if (typeCode.StartsWith("O"))
return typeof(Object);

throw new NotSupportedException();


+ 22
- 13
src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Razorvine.Pickle;
using Razorvine.Pickle.Objects;

namespace Tensorflow.NumPy
{
@@ -17,28 +18,36 @@ namespace Tensorflow.NumPy
{
public object construct(object[] args)
{
//Console.WriteLine(args.Length);
//for (int i = 0; i < args.Length; i++)
//{
// Console.WriteLine(args[i]);
//}
Console.WriteLine("MultiArrayConstructor");
if (args.Length != 3)
throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments.");
var types = (ClassDictConstructor)args[0];
if (types.module != "numpy" || types.name != "ndarray")
throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray");
var arg1 = (Object[])args[1];
var dims = new int[arg1.Length];
for (var i = 0; i < arg1.Length; i++)
{
dims[i] = (int)arg1[i];
}
var shape = new Shape(dims);

var dtype = TF_DataType.DtInvalid;
switch (args[2])
TF_DataType dtype;
string identifier;
if (args[2].GetType() == typeof(string))
identifier = (string)args[2];
else
identifier = Encoding.UTF8.GetString((byte[])args[2]);
switch (identifier)
{
case "b": dtype = TF_DataType.DtUint8Ref; break;
default: throw new NotImplementedException("cannot parse" + args[2]);
case "u": dtype = np.uint32; break;
case "c": dtype = np.complex_; break;
case "f": dtype = np.float32; break;
case "b": dtype = np.@bool; break;
default: throw new NotImplementedException($"Unsupported data type: {args[2]}");
}
return new NDArray(new Shape(dims), dtype);

return new NDArray(shape, dtype);
}
}
}

+ 94
- 5
src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs View File

@@ -1,4 +1,7 @@
using System;
using Newtonsoft.Json.Linq;
using Serilog.Debugging;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;

@@ -6,14 +9,100 @@ namespace Tensorflow.NumPy
{
public partial class NDArray
{
public NDArray reconstructedNDArray { get; set; }
public Array reconstructedArray { get; set; }
public void __setstate__(object[] args)
{
Console.WriteLine("NDArray __setstate__");
Console.WriteLine(args.Length);
for (int i = 0; i < args.Length; i++)
if (args.Length != 5)
throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments.");

var version = (int)args[0]; // version

var arg1 = (Object[])args[1];
var dims = new int[arg1.Length];
for (var i = 0; i < arg1.Length; i++)
{
dims[i] = (int)arg1[i];
}
var _ShapeLike = new Shape(dims); // shape

TF_DataType _DType_co = (TF_DataType_Warpper)args[2]; // DType

var F_continuous = (bool)args[3]; // F-continuous
if (F_continuous)
throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format.");

var data = args[4]; // Data
/*
* If we ever need another pickle format, increment the version
* number. But we should still be able to handle the old versions.
*/
if (version < 0 || version > 4)
throw new ValueError($"can't handle version {version} of numpy.dtype pickle");

// TODO: Implement the missing details and checks from the official Numpy C code here.
// https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761

if (data.GetType() == typeof(ArrayList))
{
SetState((ArrayList)data);
}
else
throw new NotImplementedException("");
}
private void SetState(ArrayList arrayList)
{
int ndim = 1;
var subArrayList = arrayList;
while (subArrayList.Count > 0 && subArrayList[0] != null && subArrayList[0].GetType() == typeof(ArrayList))
{
subArrayList = (ArrayList)subArrayList[0];
ndim += 1;
}
var type = subArrayList[0].GetType();
if (type == typeof(int))
{
Console.WriteLine(args[i]);
if (ndim == 1)
{
int[] list = (int[])arrayList.ToArray(typeof(int));
Shape shape = new Shape(new int[] { arrayList.Count });
reconstructedArray = list;
reconstructedNDArray = new NDArray(list, shape);
//SetData(new[] { new Slice() }, new NDArray(list, shape));
//set_shape(shape);
}
if (ndim == 2)
{
int secondDim = 0;
foreach (ArrayList subArray in arrayList)
{
secondDim = subArray.Count > secondDim ? subArray.Count : secondDim;
}
int[,] list = new int[arrayList.Count, secondDim];
for (int i = 0; i < arrayList.Count; i++)
{
var subArray = (ArrayList?)arrayList[i];
if (subArray == null)
throw new NullReferenceException("");
for (int j = 0; j < subArray.Count; j++)
{
var element = subArray[j];
if (element == null)
throw new NoNullAllowedException("the element of ArrayList cannot be null.");
list[i,j] = (int) element;
}
}
Shape shape = new Shape(new int[] { arrayList.Count, secondDim });
reconstructedArray = list;
reconstructedNDArray = new NDArray(list, shape);
//SetData(new[] { new Slice() }, new NDArray(list, shape));
//set_shape(shape);
}
if (ndim > 2)
throw new NotImplementedException("can't handle ArrayList with more than two dimensions.");
}
else
throw new NotImplementedException("");
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow.NumPy
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
=> nd.dtype switch
{
TF_DataType.TF_BOOL => Scalar<T>(*(bool*)nd.data),
TF_DataType.TF_UINT8 => Scalar<T>(*(byte*)nd.data),
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),


+ 3
- 1
src/TensorFlowNET.Core/Numpy/Numpy.cs View File

@@ -43,7 +43,9 @@ public partial class np
public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX;
public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
public static readonly TF_DataType @string = TF_DataType.TF_STRING;
public static readonly TF_DataType @object = TF_DataType.TF_VARIANT;
#endregion

public static double nan => double.NaN;


+ 6
- 4
src/TensorFlowNET.Keras/Datasets/Imdb.cs View File

@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Datasets
public class Imdb
{
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
string file_name = "imdb.npz";
string file_name = "simple.npz";
string dest_folder = "imdb";
/// <summary>
/// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
@@ -128,13 +128,15 @@ namespace Tensorflow.Keras.Datasets

(NDArray, NDArray) LoadX(byte[] bytes)
{
var y = np.Load_Npz<byte[]>(bytes);
return (y["x_train.npy"], y["x_test.npy"]);
var y = np.Load_Npz<int[,]>(bytes);
var x_train = y["x_train.npy"];
var x_test = y["x_test.npy"];
return (x_train, x_test);
}

(NDArray, NDArray) LoadY(byte[] bytes)
{
var y = np.Load_Npz<long[]>(bytes);
var y = np.Load_Npz<int[]>(bytes);
return (y["y_train.npy"], y["y_test.npy"]);
}



Loading…
Cancel
Save