# Patch summary (descriptive preamble; everything from the first "diff --git" onward is the unchanged patch payload).
#
# Purpose: moves the NumPy pickle-reconstruction helpers into the Tensorflow.NumPy.Pickle namespace and
# decouples them from NDArray.
#
# Files touched:
#   * NumPy/Implementation/NumPyImpl.Creation.cs
#       - adds `using Tensorflow.NumPy.Pickle;` and a `return matrix;` after the BlockCopy call.
#       - ReadObjectMatrix now returns Array instead of NDArray, no longer registers Unpickler constructors
#         per call, drops the Console.WriteLine of dims, and returns the unpickled object cast to
#         MultiArrayPickleWarpper.
#   * NumPy/Implementation/NumPyImpl.load.cs
#       - the typeof(Object) branch returns ReadObjectMatrix(...) directly; one closing brace is removed.
#       - the first magic-byte check changes from `reader.ReadByte() != 0x93` to `reader.ReadChar() != 63`.
#       - removes Console.WriteLine(header).
#   * NumPy/Pickle/DTypePickleWarpper.cs (new file)
#       - wraps a TF_DataType, has an empty __setstate__, and converts implicitly to TF_DataType;
#         replaces the TF_DataType_Warpper class removed from DtypeConstructor.cs.
#   * NumPy/Pickle/DtypeConstructor.cs (renamed from NumPy/DtypeConstructor.cs)
#       - namespace change; returns DTypePickleWarpper instead of TF_DataType_Warpper.
#   * NumPy/Pickle/MultiArrayConstructor.cs (renamed from NumPy/MultiArrayConstructor.cs)
#       - namespace change and whitespace cleanup; construct() returns MultiArrayPickleWarpper(shape, dtype)
#         instead of NDArray(shape, dtype).
#   * NumPy/Pickle/MultiArrayPickleWarpper.cs (renamed from NumPy/NDArray.Pickle.cs)
#       - becomes a standalone class rather than a partial of NDArray; adds reconstructedShape and
#         reconstructedDType plus a (Shape, TF_DataType) constructor; renames reconstructedArray to
#         reconstructedMultiArray and SetState to Reconstruct; adds implicit conversions to Array and NDArray.
#   * tensorflow.cs
#       - registers MultiArrayConstructor for ("numpy.core.multiarray", "_reconstruct") and DtypeConstructor
#         for ("numpy", "dtype") with Unpickler.
#   * TensorFlowNET.Keras/Datasets/Imdb.cs
#       - removes seven using directives; file_name changes from "simple.npz" to "imdb.npz"; LoadY is now
#         called before LoadX; LoadX returns the two arrays without intermediate locals.
#   * test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
#       - removes `using System.Collections.Generic;`, adds [Ignore] to GetData, and comments out both
#         pad_sequences calls.
#
# NOTE(review): the magic-byte check now compares a decoded char against 63 ('?') instead of the raw byte
#   0x93. Whether 0x93 decodes to 63 depends on the BinaryReader's text encoding, which is not visible in
#   this patch -- confirm the reader's encoding before relying on this check.
# NOTE(review): with [Ignore] and the pad_sequences calls commented out, GetData no longer exercises padding
#   of the loaded IMDB data -- confirm this is intentional and tracked.
# NOTE(review): LoadX and LoadY each call np.Load_Npz(bytes) on the same byte array, so the archive is
#   loaded twice per load -- consider sharing one result.
# NOTE(review): the Unpickler.registerConstructor calls are added inside a member of tensorflow.cs that
#   precedes VERSION (presumably the tensorflow constructor -- confirm); registration is then global
#   rather than per ReadObjectMatrix call.
# NOTE(review): this copy of the patch appears to have lost its line breaks, and generic type arguments /
#   XML doc tags look stripped (e.g. `meshgrid(T[] array`, `/// ///`); presumably it cannot be applied
#   as-is -- obtain the original patch before applying.
#
diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs index 7b79f83c..fa4ef019 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Text; using Tensorflow.Util; using Razorvine.Pickle; +using Tensorflow.NumPy.Pickle; using static Tensorflow.Binding; namespace Tensorflow.NumPy @@ -94,20 +95,15 @@ namespace Tensorflow.NumPy var buffer = reader.ReadBytes(bytes * total); System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length); + return matrix; } - NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape) + Array ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape) { Stream stream = reader.BaseStream; - Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor()); - Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor()); - var unpickler = new Unpickler(); - - NDArray result = (NDArray) unpickler.load(stream); - Console.WriteLine(result.dims); - return result; + return (MultiArrayPickleWarpper)unpickler.load(stream); } public (NDArray, NDArray) meshgrid(T[] array, bool copy = true, bool sparse = false) diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs index bbe48e6a..199e5ced 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs @@ -30,17 +30,12 @@ namespace Tensorflow.NumPy //return ReadStringMatrix(reader, matrix, bytes, type, shape); if (type == typeof(Object)) - { - NDArray res = ReadObjectMatrix(reader, matrix, shape); - // res = res.reconstructedNDArray; - return res.reconstructedArray; - } + return ReadObjectMatrix(reader, matrix, shape); else { return 
ReadValueMatrix(reader, matrix, bytes, type, shape); } } - } public T Load(Stream stream) @@ -59,7 +54,7 @@ namespace Tensorflow.NumPy shape = null; // The first 6 bytes are a magic string: exactly "x93NUMPY" - if (reader.ReadByte() != 0x93) return false; + if (reader.ReadChar() != 63) return false; if (reader.ReadChar() != 'N') return false; if (reader.ReadChar() != 'U') return false; if (reader.ReadChar() != 'M') return false; @@ -75,7 +70,6 @@ namespace Tensorflow.NumPy ushort len = reader.ReadUInt16(); string header = new String(reader.ReadChars(len)); - Console.WriteLine(header); string mark = "'descr': '"; int s = header.IndexOf(mark) + mark.Length; int e = header.IndexOf("'", s + 1); diff --git a/src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs b/src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs new file mode 100644 index 00000000..5dff6c16 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy.Pickle +{ + public class DTypePickleWarpper + { + TF_DataType dtype { get; set; } + public DTypePickleWarpper(TF_DataType dtype) + { + this.dtype = dtype; + } + public void __setstate__(object[] args) { } + public static implicit operator TF_DataType(DTypePickleWarpper dTypeWarpper) + { + return dTypeWarpper.dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs b/src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs similarity index 77% rename from src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs rename to src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs index 30ef82df..160c7d4e 100644 --- a/src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs +++ b/src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs @@ -4,7 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text; using Razorvine.Pickle; -namespace Tensorflow.NumPy +namespace Tensorflow.NumPy.Pickle { /// /// @@ -46,20 
+46,7 @@ namespace Tensorflow.NumPy dtype = np.@object; else throw new NotSupportedException(); - return new TF_DataType_Warpper(dtype); - } - } - public class TF_DataType_Warpper - { - TF_DataType dtype { get; set; } - public TF_DataType_Warpper(TF_DataType dtype) - { - this.dtype = dtype; - } - public void __setstate__(object[] args) { } - public static implicit operator TF_DataType(TF_DataType_Warpper dtypeWarpper) - { - return dtypeWarpper.dtype; + return new DTypePickleWarpper(dtype); } } } diff --git a/src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs similarity index 91% rename from src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs rename to src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs index 43eda23e..885f368c 100644 --- a/src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs +++ b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs @@ -5,7 +5,7 @@ using System.Text; using Razorvine.Pickle; using Razorvine.Pickle.Objects; -namespace Tensorflow.NumPy +namespace Tensorflow.NumPy.Pickle { /// /// Creates multiarrays of objects. Returns a primitive type multiarray such as int[][] if @@ -18,14 +18,14 @@ namespace Tensorflow.NumPy { public object construct(object[] args) { - if (args.Length != 3) + if (args.Length != 3) throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. 
Given {args.Length} arguments."); - + var types = (ClassDictConstructor)args[0]; - if (types.module != "numpy" || types.name != "ndarray") + if (types.module != "numpy" || types.name != "ndarray") throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray"); - - var arg1 = (Object[])args[1]; + + var arg1 = (object[])args[1]; var dims = new int[arg1.Length]; for (var i = 0; i < arg1.Length; i++) { @@ -47,7 +47,7 @@ namespace Tensorflow.NumPy case "b": dtype = np.@bool; break; default: throw new NotImplementedException($"Unsupported data type: {args[2]}"); } - return new NDArray(shape, dtype); + return new MultiArrayPickleWarpper(shape, dtype); } } } diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs similarity index 77% rename from src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs rename to src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs index 62720826..af8d1ecc 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs +++ b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs @@ -5,12 +5,19 @@ using System.Collections; using System.Collections.Generic; using System.Text; -namespace Tensorflow.NumPy +namespace Tensorflow.NumPy.Pickle { - public partial class NDArray + public class MultiArrayPickleWarpper { + public Shape reconstructedShape { get; set; } + public TF_DataType reconstructedDType { get; set; } public NDArray reconstructedNDArray { get; set; } - public Array reconstructedArray { get; set; } + public Array reconstructedMultiArray { get; set; } + public MultiArrayPickleWarpper(Shape shape, TF_DataType dtype) + { + reconstructedShape = shape; + reconstructedDType = dtype; + } public void __setstate__(object[] args) { if (args.Length != 5) @@ -18,7 +25,7 @@ namespace Tensorflow.NumPy var version = (int)args[0]; // version - var arg1 = (Object[])args[1]; + var arg1 = (object[])args[1]; var dims = new int[arg1.Length]; for (var i = 0; i < 
arg1.Length; i++) { @@ -26,7 +33,7 @@ namespace Tensorflow.NumPy } var _ShapeLike = new Shape(dims); // shape - TF_DataType _DType_co = (TF_DataType_Warpper)args[2]; // DType + TF_DataType _DType_co = (DTypePickleWarpper)args[2]; // DType var F_continuous = (bool)args[3]; // F-continuous if (F_continuous) @@ -45,12 +52,12 @@ namespace Tensorflow.NumPy if (data.GetType() == typeof(ArrayList)) { - SetState((ArrayList)data); + Reconstruct((ArrayList)data); } else throw new NotImplementedException(""); } - private void SetState(ArrayList arrayList) + private void Reconstruct(ArrayList arrayList) { int ndim = 1; var subArrayList = arrayList; @@ -66,10 +73,8 @@ namespace Tensorflow.NumPy { int[] list = (int[])arrayList.ToArray(typeof(int)); Shape shape = new Shape(new int[] { arrayList.Count }); - reconstructedArray = list; + reconstructedMultiArray = list; reconstructedNDArray = new NDArray(list, shape); - //SetData(new[] { new Slice() }, new NDArray(list, shape)); - //set_shape(shape); } if (ndim == 2) { @@ -89,14 +94,12 @@ namespace Tensorflow.NumPy var element = subArray[j]; if (element == null) throw new NoNullAllowedException("the element of ArrayList cannot be null."); - list[i,j] = (int) element; + list[i, j] = (int)element; } } Shape shape = new Shape(new int[] { arrayList.Count, secondDim }); - reconstructedArray = list; + reconstructedMultiArray = list; reconstructedNDArray = new NDArray(list, shape); - //SetData(new[] { new Slice() }, new NDArray(list, shape)); - //set_shape(shape); } if (ndim > 2) throw new NotImplementedException("can't handle ArrayList with more than two dimensions."); @@ -104,5 +107,13 @@ namespace Tensorflow.NumPy else throw new NotImplementedException(""); } + public static implicit operator Array(MultiArrayPickleWarpper arrayWarpper) + { + return arrayWarpper.reconstructedMultiArray; + } + public static implicit operator NDArray(MultiArrayPickleWarpper arrayWarpper) + { + return arrayWarpper.reconstructedNDArray; + } } } diff --git 
a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index dc4e48da..e368b37c 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Razorvine.Pickle; using Serilog; using Serilog.Core; using System.Reflection; @@ -22,6 +23,7 @@ using Tensorflow.Contexts; using Tensorflow.Eager; using Tensorflow.Gradients; using Tensorflow.Keras; +using Tensorflow.NumPy.Pickle; namespace Tensorflow { @@ -98,6 +100,10 @@ namespace Tensorflow "please visit https://github.com/SciSharp/TensorFlow.NET. If it still not work after installing the backend, please submit an " + "issue to https://github.com/SciSharp/TensorFlow.NET/issues"); } + + // register numpy reconstructor for pickle + Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor()); + Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor()); } public string VERSION => c_api.StringPiece(c_api.TF_Version()); diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs index 6808035c..a992ae84 100644 --- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs +++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs @@ -5,13 +5,6 @@ using System.Text; using Tensorflow.Keras.Utils; using Tensorflow.NumPy; using System.Linq; -using Google.Protobuf.Collections; -using Microsoft.VisualBasic; -using OneOf.Types; -using static HDF.PInvoke.H5; -using System.Data; -using System.Reflection.Emit; -using System.Xml.Linq; namespace Tensorflow.Keras.Datasets { @@ -70,8 +63,9 @@ namespace Tensorflow.Keras.Datasets public class Imdb { string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; - string file_name = "simple.npz"; + string file_name = "imdb.npz"; string dest_folder = "imdb"; + /// /// Loads the [IMDB 
dataset](https://ai.stanford.edu/~amaas/data/sentiment/). /// @@ -95,8 +89,9 @@ namespace Tensorflow.Keras.Datasets { var dst = Download(); var fileBytes = File.ReadAllBytes(Path.Combine(dst, file_name)); - var (x_train, x_test) = LoadX(fileBytes); var (y_train, y_test) = LoadY(fileBytes); + var (x_train, x_test) = LoadX(fileBytes); + /*var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt")); var x_train_string = new string[lines.Length]; var y_train = np.zeros(new int[] { lines.Length }, np.int64); @@ -129,14 +124,12 @@ namespace Tensorflow.Keras.Datasets (NDArray, NDArray) LoadX(byte[] bytes) { var y = np.Load_Npz(bytes); - var x_train = y["x_train.npy"]; - var x_test = y["x_test.npy"]; - return (x_train, x_test); + return (y["x_train.npy"], y["x_test.npy"]); } (NDArray, NDArray) LoadY(byte[] bytes) { - var y = np.Load_Npz(bytes); + var y = np.Load_Npz(bytes); return (y["y_train.npy"], y["y_test.npy"]); } diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 778290bb..db6252ef 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -1,6 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -197,6 +196,7 @@ namespace TensorFlowNET.UnitTest.Dataset Assert.IsFalse(allEqual); } + [Ignore] [TestMethod] public void GetData() { @@ -209,8 +209,8 @@ namespace TensorFlowNET.UnitTest.Dataset var y_val = dataset.Test.Item2; print(len(x_train) + "Training sequences"); print(len(x_val) + "Validation sequences"); - x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable)x_train, maxlen: maxlen); - x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable)x_val, maxlen: maxlen); + //x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable)x_train, maxlen: maxlen); + 
//x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable)x_val, maxlen: maxlen); } } }