diff --git a/src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs b/src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs index f84f408e..30ef82df 100644 --- a/src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs +++ b/src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs @@ -16,25 +16,50 @@ namespace Tensorflow.NumPy { public object construct(object[] args) { - Console.WriteLine("DtypeConstructor"); - Console.WriteLine(args.Length); - for (int i = 0; i < args.Length; i++) - { - Console.WriteLine(args[i]); - } - return new demo(); + var typeCode = (string)args[0]; + TF_DataType dtype; + if (typeCode == "b1") + dtype = np.@bool; + else if (typeCode == "i1") + dtype = np.@byte; + else if (typeCode == "i2") + dtype = np.int16; + else if (typeCode == "i4") + dtype = np.int32; + else if (typeCode == "i8") + dtype = np.int64; + else if (typeCode == "u1") + dtype = np.ubyte; + else if (typeCode == "u2") + dtype = np.uint16; + else if (typeCode == "u4") + dtype = np.uint32; + else if (typeCode == "u8") + dtype = np.uint64; + else if (typeCode == "f4") + dtype = np.float32; + else if (typeCode == "f8") + dtype = np.float64; + else if (typeCode.StartsWith("S")) + dtype = np.@string; + else if (typeCode.StartsWith("O")) + dtype = np.@object; + else + throw new NotSupportedException(); + return new TF_DataType_Warpper(dtype); } } - class demo + public class TF_DataType_Warpper { - public void __setstate__(object[] args) + TF_DataType dtype { get; set; } + public TF_DataType_Warpper(TF_DataType dtype) { - Console.WriteLine("demo __setstate__"); - Console.WriteLine(args.Length); - for (int i = 0; i < args.Length; i++) - { - Console.WriteLine(args[i]); - } + this.dtype = dtype; + } + public void __setstate__(object[] args) { } + public static implicit operator TF_DataType(TF_DataType_Warpper dtypeWarpper) + { + return dtypeWarpper.dtype; } } } diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs index 80b62198..7b79f83c 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs @@ -99,9 +99,6 @@ namespace Tensorflow.NumPy NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape) { - //int data = reader.ReadByte(); - //Console.WriteLine(data); - //Console.WriteLine(reader.ReadByte()); Stream stream = reader.BaseStream; Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor()); Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor()); diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs index 789f119a..bbe48e6a 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs @@ -28,17 +28,17 @@ namespace Tensorflow.NumPy //if (type == typeof(String)) //return ReadStringMatrix(reader, matrix, bytes, type, shape); - NDArray res = ReadObjectMatrix(reader, matrix, shape); - Console.WriteLine("LoadMatrix"); - Console.WriteLine(res.dims[0]); - Console.WriteLine((int)res[0][0]); - Console.WriteLine(res.dims[1]); - //if (type == typeof(Object)) - //{ - - //} - //else - return ReadValueMatrix(reader, matrix, bytes, type, shape); + + if (type == typeof(Object)) + { + NDArray res = ReadObjectMatrix(reader, matrix, shape); + // res = res.reconstructedNDArray; + return res.reconstructedArray; + } + else + { + return ReadValueMatrix(reader, matrix, bytes, type, shape); + } } } @@ -133,7 +133,7 @@ namespace Tensorflow.NumPy return typeof(Double); if (typeCode.StartsWith("S")) return typeof(String); - if (typeCode == "O") + if (typeCode.StartsWith("O")) return typeof(Object); throw new NotSupportedException(); diff --git a/src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs b/src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs index 92927cd5..43eda23e 100644 --- a/src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs +++ b/src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Text; using Razorvine.Pickle; +using Razorvine.Pickle.Objects; namespace Tensorflow.NumPy { @@ -17,28 +18,36 @@ namespace Tensorflow.NumPy { public object construct(object[] args) { - //Console.WriteLine(args.Length); - //for (int i = 0; i < args.Length; i++) - //{ - // Console.WriteLine(args[i]); - //} - Console.WriteLine("MultiArrayConstructor"); - + if (args.Length != 3) + throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments."); + + var types = (ClassDictConstructor)args[0]; + if (types.module != "numpy" || types.name != "ndarray") + throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray"); + var arg1 = (Object[])args[1]; var dims = new int[arg1.Length]; for (var i = 0; i < arg1.Length; i++) { dims[i] = (int)arg1[i]; } + var shape = new Shape(dims); - var dtype = TF_DataType.DtInvalid; - switch (args[2]) + TF_DataType dtype; + string identifier; + if (args[2].GetType() == typeof(string)) + identifier = (string)args[2]; + else + identifier = Encoding.UTF8.GetString((byte[])args[2]); + switch (identifier) { - case "b": dtype = TF_DataType.DtUint8Ref; break; - default: throw new NotImplementedException("cannot parse" + args[2]); + case "u": dtype = np.uint32; break; + case "c": dtype = np.complex_; break; + case "f": dtype = np.float32; break; + case "b": dtype = np.@bool; break; + default: throw new NotImplementedException($"Unsupported data type: {args[2]}"); } - return new NDArray(new Shape(dims), dtype); - + return new NDArray(shape, dtype); } } } diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs index b4d66243..62720826 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs @@ -1,4 +1,7 @@ -using System; +using Newtonsoft.Json.Linq; +using Serilog.Debugging; +using System; +using System.Collections; using System.Collections.Generic; using System.Text; @@ -6,14 +9,100 @@ namespace Tensorflow.NumPy { public partial class NDArray { + public NDArray reconstructedNDArray { get; set; } + public Array reconstructedArray { get; set; } public void __setstate__(object[] args) { - Console.WriteLine("NDArray __setstate__"); - Console.WriteLine(args.Length); - for (int i = 0; i < args.Length; i++) + if (args.Length != 5) + throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments."); + + var version = (int)args[0]; // version + + var arg1 = (Object[])args[1]; + var dims = new int[arg1.Length]; + for (var i = 0; i < arg1.Length; i++) + { + dims[i] = (int)arg1[i]; + } + var _ShapeLike = new Shape(dims); // shape + + TF_DataType _DType_co = (TF_DataType_Warpper)args[2]; // DType + + var F_continuous = (bool)args[3]; // F-continuous + if (F_continuous) + throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format."); + + var data = args[4]; // Data + /* + * If we ever need another pickle format, increment the version + * number. But we should still be able to handle the old versions. + */ + if (version < 0 || version > 4) + throw new ValueError($"can't handle version {version} of numpy.dtype pickle"); + + // TODO: Implement the missing details and checks from the official Numpy C code here. + // https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761 + + if (data.GetType() == typeof(ArrayList)) + { + SetState((ArrayList)data); + } + else + throw new NotImplementedException(""); + } + private void SetState(ArrayList arrayList) + { + int ndim = 1; + var subArrayList = arrayList; + while (subArrayList.Count > 0 && subArrayList[0] != null && subArrayList[0].GetType() == typeof(ArrayList)) + { + subArrayList = (ArrayList)subArrayList[0]; + ndim += 1; + } + var type = subArrayList[0].GetType(); + if (type == typeof(int)) { - Console.WriteLine(args[i]); + if (ndim == 1) + { + int[] list = (int[])arrayList.ToArray(typeof(int)); + Shape shape = new Shape(new int[] { arrayList.Count }); + reconstructedArray = list; + reconstructedNDArray = new NDArray(list, shape); + //SetData(new[] { new Slice() }, new NDArray(list, shape)); + //set_shape(shape); + } + if (ndim == 2) + { + int secondDim = 0; + foreach (ArrayList subArray in arrayList) + { + secondDim = subArray.Count > secondDim ? subArray.Count : secondDim; + } + int[,] list = new int[arrayList.Count, secondDim]; + for (int i = 0; i < arrayList.Count; i++) + { + var subArray = (ArrayList?)arrayList[i]; + if (subArray == null) + throw new NullReferenceException(""); + for (int j = 0; j < subArray.Count; j++) + { + var element = subArray[j]; + if (element == null) + throw new NoNullAllowedException("the element of ArrayList cannot be null."); + list[i,j] = (int) element; + } + } + Shape shape = new Shape(new int[] { arrayList.Count, secondDim }); + reconstructedArray = list; + reconstructedNDArray = new NDArray(list, shape); + //SetData(new[] { new Slice() }, new NDArray(list, shape)); + //set_shape(shape); + } + if (ndim > 2) + throw new NotImplementedException("can't handle ArrayList with more than two dimensions."); } + else + throw new NotImplementedException(""); } } } diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs index c8c2d45f..4c64eba7 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs @@ -10,6 +10,7 @@ namespace Tensorflow.NumPy public unsafe static T Scalar(NDArray nd) where T : unmanaged => nd.dtype switch { + TF_DataType.TF_BOOL => Scalar(*(bool*)nd.data), TF_DataType.TF_UINT8 => Scalar(*(byte*)nd.data), TF_DataType.TF_FLOAT => Scalar(*(float*)nd.data), TF_DataType.TF_INT32 => Scalar(*(int*)nd.data), diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index 72d2e981..fee2d63f 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -43,7 +43,9 @@ public partial class np public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE; public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX; public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64; - public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128; + public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128; + public static readonly TF_DataType @string = TF_DataType.TF_STRING; + public static readonly TF_DataType @object = TF_DataType.TF_VARIANT; #endregion public static double nan => double.NaN; diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs index 016b352d..6808035c 100644 --- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs +++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs @@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Datasets public class Imdb { string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; - string file_name = "imdb.npz"; + string file_name = "simple.npz"; string dest_folder = "imdb"; /// /// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/). @@ -128,13 +128,15 @@ namespace Tensorflow.Keras.Datasets (NDArray, NDArray) LoadX(byte[] bytes) { - var y = np.Load_Npz(bytes); - return (y["x_train.npy"], y["x_test.npy"]); + var y = np.Load_Npz(bytes); + var x_train = y["x_train.npy"]; + var x_test = y["x_test.npy"]; + return (x_train, x_test); } (NDArray, NDArray) LoadY(byte[] bytes) { - var y = np.Load_Npz(bytes); + var y = np.Load_Npz(bytes); return (y["y_train.npy"], y["y_test.npy"]); }