@@ -5,6 +5,7 @@ using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Util; | |||
using Razorvine.Pickle; | |||
using Tensorflow.NumPy.Pickle; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.NumPy | |||
@@ -94,20 +95,15 @@ namespace Tensorflow.NumPy | |||
var buffer = reader.ReadBytes(bytes * total); | |||
System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length); | |||
return matrix; | |||
} | |||
NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape) | |||
Array ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape) | |||
{ | |||
Stream stream = reader.BaseStream; | |||
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor()); | |||
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor()); | |||
var unpickler = new Unpickler(); | |||
NDArray result = (NDArray) unpickler.load(stream); | |||
Console.WriteLine(result.dims); | |||
return result; | |||
return (MultiArrayPickleWarpper)unpickler.load(stream); | |||
} | |||
public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false) | |||
@@ -30,17 +30,12 @@ namespace Tensorflow.NumPy | |||
//return ReadStringMatrix(reader, matrix, bytes, type, shape); | |||
if (type == typeof(Object)) | |||
{ | |||
NDArray res = ReadObjectMatrix(reader, matrix, shape); | |||
// res = res.reconstructedNDArray; | |||
return res.reconstructedArray; | |||
} | |||
return ReadObjectMatrix(reader, matrix, shape); | |||
else | |||
{ | |||
return ReadValueMatrix(reader, matrix, bytes, type, shape); | |||
} | |||
} | |||
} | |||
public T Load<T>(Stream stream) | |||
@@ -59,7 +54,7 @@ namespace Tensorflow.NumPy | |||
shape = null; | |||
// The first 6 bytes are a magic string: exactly "x93NUMPY" | |||
if (reader.ReadByte() != 0x93) return false; | |||
if (reader.ReadChar() != 63) return false; | |||
if (reader.ReadChar() != 'N') return false; | |||
if (reader.ReadChar() != 'U') return false; | |||
if (reader.ReadChar() != 'M') return false; | |||
@@ -75,7 +70,6 @@ namespace Tensorflow.NumPy | |||
ushort len = reader.ReadUInt16(); | |||
string header = new String(reader.ReadChars(len)); | |||
Console.WriteLine(header); | |||
string mark = "'descr': '"; | |||
int s = header.IndexOf(mark) + mark.Length; | |||
int e = header.IndexOf("'", s + 1); | |||
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.NumPy.Pickle | |||
{ | |||
public class DTypePickleWarpper | |||
{ | |||
TF_DataType dtype { get; set; } | |||
public DTypePickleWarpper(TF_DataType dtype) | |||
{ | |||
this.dtype = dtype; | |||
} | |||
public void __setstate__(object[] args) { } | |||
public static implicit operator TF_DataType(DTypePickleWarpper dTypeWarpper) | |||
{ | |||
return dTypeWarpper.dtype; | |||
} | |||
} | |||
} |
@@ -4,7 +4,7 @@ using System.Diagnostics.CodeAnalysis; | |||
using System.Text; | |||
using Razorvine.Pickle; | |||
namespace Tensorflow.NumPy | |||
namespace Tensorflow.NumPy.Pickle | |||
{ | |||
/// <summary> | |||
/// | |||
@@ -46,20 +46,7 @@ namespace Tensorflow.NumPy | |||
dtype = np.@object; | |||
else | |||
throw new NotSupportedException(); | |||
return new TF_DataType_Warpper(dtype); | |||
} | |||
} | |||
public class TF_DataType_Warpper | |||
{ | |||
TF_DataType dtype { get; set; } | |||
public TF_DataType_Warpper(TF_DataType dtype) | |||
{ | |||
this.dtype = dtype; | |||
} | |||
public void __setstate__(object[] args) { } | |||
public static implicit operator TF_DataType(TF_DataType_Warpper dtypeWarpper) | |||
{ | |||
return dtypeWarpper.dtype; | |||
return new DTypePickleWarpper(dtype); | |||
} | |||
} | |||
} |
@@ -5,7 +5,7 @@ using System.Text; | |||
using Razorvine.Pickle; | |||
using Razorvine.Pickle.Objects; | |||
namespace Tensorflow.NumPy | |||
namespace Tensorflow.NumPy.Pickle | |||
{ | |||
/// <summary> | |||
/// Creates multiarrays of objects. Returns a primitive type multiarray such as int[][] if | |||
@@ -18,14 +18,14 @@ namespace Tensorflow.NumPy | |||
{ | |||
public object construct(object[] args) | |||
{ | |||
if (args.Length != 3) | |||
if (args.Length != 3) | |||
throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments."); | |||
var types = (ClassDictConstructor)args[0]; | |||
if (types.module != "numpy" || types.name != "ndarray") | |||
if (types.module != "numpy" || types.name != "ndarray") | |||
throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray"); | |||
var arg1 = (Object[])args[1]; | |||
var arg1 = (object[])args[1]; | |||
var dims = new int[arg1.Length]; | |||
for (var i = 0; i < arg1.Length; i++) | |||
{ | |||
@@ -47,7 +47,7 @@ namespace Tensorflow.NumPy | |||
case "b": dtype = np.@bool; break; | |||
default: throw new NotImplementedException($"Unsupported data type: {args[2]}"); | |||
} | |||
return new NDArray(shape, dtype); | |||
return new MultiArrayPickleWarpper(shape, dtype); | |||
} | |||
} | |||
} |
@@ -5,12 +5,19 @@ using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.NumPy | |||
namespace Tensorflow.NumPy.Pickle | |||
{ | |||
public partial class NDArray | |||
public class MultiArrayPickleWarpper | |||
{ | |||
public Shape reconstructedShape { get; set; } | |||
public TF_DataType reconstructedDType { get; set; } | |||
public NDArray reconstructedNDArray { get; set; } | |||
public Array reconstructedArray { get; set; } | |||
public Array reconstructedMultiArray { get; set; } | |||
public MultiArrayPickleWarpper(Shape shape, TF_DataType dtype) | |||
{ | |||
reconstructedShape = shape; | |||
reconstructedDType = dtype; | |||
} | |||
public void __setstate__(object[] args) | |||
{ | |||
if (args.Length != 5) | |||
@@ -18,7 +25,7 @@ namespace Tensorflow.NumPy | |||
var version = (int)args[0]; // version | |||
var arg1 = (Object[])args[1]; | |||
var arg1 = (object[])args[1]; | |||
var dims = new int[arg1.Length]; | |||
for (var i = 0; i < arg1.Length; i++) | |||
{ | |||
@@ -26,7 +33,7 @@ namespace Tensorflow.NumPy | |||
} | |||
var _ShapeLike = new Shape(dims); // shape | |||
TF_DataType _DType_co = (TF_DataType_Warpper)args[2]; // DType | |||
TF_DataType _DType_co = (DTypePickleWarpper)args[2]; // DType | |||
var F_continuous = (bool)args[3]; // F-continuous | |||
if (F_continuous) | |||
@@ -45,12 +52,12 @@ namespace Tensorflow.NumPy | |||
if (data.GetType() == typeof(ArrayList)) | |||
{ | |||
SetState((ArrayList)data); | |||
Reconstruct((ArrayList)data); | |||
} | |||
else | |||
throw new NotImplementedException(""); | |||
} | |||
private void SetState(ArrayList arrayList) | |||
private void Reconstruct(ArrayList arrayList) | |||
{ | |||
int ndim = 1; | |||
var subArrayList = arrayList; | |||
@@ -66,10 +73,8 @@ namespace Tensorflow.NumPy | |||
{ | |||
int[] list = (int[])arrayList.ToArray(typeof(int)); | |||
Shape shape = new Shape(new int[] { arrayList.Count }); | |||
reconstructedArray = list; | |||
reconstructedMultiArray = list; | |||
reconstructedNDArray = new NDArray(list, shape); | |||
//SetData(new[] { new Slice() }, new NDArray(list, shape)); | |||
//set_shape(shape); | |||
} | |||
if (ndim == 2) | |||
{ | |||
@@ -89,14 +94,12 @@ namespace Tensorflow.NumPy | |||
var element = subArray[j]; | |||
if (element == null) | |||
throw new NoNullAllowedException("the element of ArrayList cannot be null."); | |||
list[i,j] = (int) element; | |||
list[i, j] = (int)element; | |||
} | |||
} | |||
Shape shape = new Shape(new int[] { arrayList.Count, secondDim }); | |||
reconstructedArray = list; | |||
reconstructedMultiArray = list; | |||
reconstructedNDArray = new NDArray(list, shape); | |||
//SetData(new[] { new Slice() }, new NDArray(list, shape)); | |||
//set_shape(shape); | |||
} | |||
if (ndim > 2) | |||
throw new NotImplementedException("can't handle ArrayList with more than two dimensions."); | |||
@@ -104,5 +107,13 @@ namespace Tensorflow.NumPy | |||
else | |||
throw new NotImplementedException(""); | |||
} | |||
public static implicit operator Array(MultiArrayPickleWarpper arrayWarpper) | |||
{ | |||
return arrayWarpper.reconstructedMultiArray; | |||
} | |||
public static implicit operator NDArray(MultiArrayPickleWarpper arrayWarpper) | |||
{ | |||
return arrayWarpper.reconstructedNDArray; | |||
} | |||
} | |||
} |
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Razorvine.Pickle; | |||
using Serilog; | |||
using Serilog.Core; | |||
using System.Reflection; | |||
@@ -22,6 +23,7 @@ using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Keras; | |||
using Tensorflow.NumPy.Pickle; | |||
namespace Tensorflow | |||
{ | |||
@@ -98,6 +100,10 @@ namespace Tensorflow | |||
"please visit https://github.com/SciSharp/TensorFlow.NET. If it still not work after installing the backend, please submit an " + | |||
"issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
// register numpy reconstructor for pickle | |||
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor()); | |||
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor()); | |||
} | |||
public string VERSION => c_api.StringPiece(c_api.TF_Version()); | |||
@@ -5,13 +5,6 @@ using System.Text; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.NumPy; | |||
using System.Linq; | |||
using Google.Protobuf.Collections; | |||
using Microsoft.VisualBasic; | |||
using OneOf.Types; | |||
using static HDF.PInvoke.H5; | |||
using System.Data; | |||
using System.Reflection.Emit; | |||
using System.Xml.Linq; | |||
namespace Tensorflow.Keras.Datasets | |||
{ | |||
@@ -70,8 +63,9 @@ namespace Tensorflow.Keras.Datasets | |||
public class Imdb | |||
{ | |||
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; | |||
string file_name = "simple.npz"; | |||
string file_name = "imdb.npz"; | |||
string dest_folder = "imdb"; | |||
/// <summary> | |||
/// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/). | |||
/// </summary> | |||
@@ -95,8 +89,9 @@ namespace Tensorflow.Keras.Datasets | |||
{ | |||
var dst = Download(); | |||
var fileBytes = File.ReadAllBytes(Path.Combine(dst, file_name)); | |||
var (x_train, x_test) = LoadX(fileBytes); | |||
var (y_train, y_test) = LoadY(fileBytes); | |||
var (x_train, x_test) = LoadX(fileBytes); | |||
/*var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt")); | |||
var x_train_string = new string[lines.Length]; | |||
var y_train = np.zeros(new int[] { lines.Length }, np.int64); | |||
@@ -129,14 +124,12 @@ namespace Tensorflow.Keras.Datasets | |||
(NDArray, NDArray) LoadX(byte[] bytes) | |||
{ | |||
var y = np.Load_Npz<int[,]>(bytes); | |||
var x_train = y["x_train.npy"]; | |||
var x_test = y["x_test.npy"]; | |||
return (x_train, x_test); | |||
return (y["x_train.npy"], y["x_test.npy"]); | |||
} | |||
(NDArray, NDArray) LoadY(byte[] bytes) | |||
{ | |||
var y = np.Load_Npz<int[]>(bytes); | |||
var y = np.Load_Npz<long[]>(bytes); | |||
return (y["y_train.npy"], y["y_test.npy"]); | |||
} | |||
@@ -1,6 +1,5 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
@@ -197,6 +196,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
Assert.IsFalse(allEqual); | |||
} | |||
[Ignore] | |||
[TestMethod] | |||
public void GetData() | |||
{ | |||
@@ -209,8 +209,8 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
var y_val = dataset.Test.Item2; | |||
print(len(x_train) + "Training sequences"); | |||
print(len(x_val) + "Validation sequences"); | |||
x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_train, maxlen: maxlen); | |||
x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_val, maxlen: maxlen); | |||
//x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_train, maxlen: maxlen); | |||
//x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_val, maxlen: maxlen); | |||
} | |||
} | |||
} |