
init pickle support for np.load of object-dtype .npy files

tags/v0.110.4-Transformer-Model
lingbai-kong committed 2 years ago · commit aac52940ad
8 changed files with 215 additions and 9 deletions
  1. src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs (+40, -0)
  2. src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs (+17, -1)
  3. src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs (+18, -4)
  4. src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs (+44, -0)
  5. src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs (+19, -0)
  6. src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+1, -0)
  7. src/TensorFlowNET.Keras/Datasets/Imdb.cs (+59, -4)
  8. test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs (+17, -0)

src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs (+40, -0)

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Razorvine.Pickle;

namespace Tensorflow.NumPy
{
/// <summary>
/// Pickle constructor for `numpy.dtype` objects. Currently a placeholder
/// that logs its arguments and returns a stub for __setstate__ to land on.
/// </summary>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
class DtypeConstructor : IObjectConstructor
{
public object construct(object[] args)
{
Console.WriteLine("DtypeConstructor");
Console.WriteLine(args.Length);
for (int i = 0; i < args.Length; i++)
{
Console.WriteLine(args[i]);
}
return new demo();
}
}
// Temporary stand-in: the unpickler calls __setstate__ on the constructed
// object with the dtype's field metadata, so the stub must expose it.
class demo
{
public void __setstate__(object[] args)
{
Console.WriteLine("demo __setstate__");
Console.WriteLine(args.Length);
for (int i = 0; i < args.Length; i++)
{
Console.WriteLine(args[i]);
}
}
}
}
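
Note: numpy pickles a dtype as a call to numpy.dtype(name, align, copy), e.g. ("O8", False, True) for object arrays, followed by a __setstate__ carrying field metadata, which is what the demo stub above absorbs. A hedged sketch of where this constructor could go (the TF_DataType mapping and the holder type are assumptions, not part of this commit):

// Sketch only: resolve the pickled dtype name (args[0]) to a TF_DataType and
// keep it in a small holder; the unpickler still calls __setstate__ on the
// result, so the holder must expose that method.
class DtypeHolder
{
    public TF_DataType dtype;
    public void __setstate__(object[] args) { /* endianness etc., unused here */ }
}

class DtypeConstructorSketch : IObjectConstructor
{
    public object construct(object[] args)
    {
        var name = (string)args[0];   // e.g. "i4", "i8", "f8", "O8"
        var holder = new DtypeHolder();
        switch (name)
        {
            case "i4": holder.dtype = TF_DataType.TF_INT32; break;
            case "i8": holder.dtype = TF_DataType.TF_INT64; break;
            case "f8": holder.dtype = TF_DataType.TF_DOUBLE; break;
            case "O8": holder.dtype = TF_DataType.TF_VARIANT; break; // assumption
            default: throw new NotImplementedException($"dtype '{name}' not mapped");
        }
        return holder;
    }
}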

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs (+17, -1)

@@ -4,6 +4,7 @@ using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Util;
using Razorvine.Pickle;
using static Tensorflow.Binding;

namespace Tensorflow.NumPy
@@ -93,10 +94,25 @@ namespace Tensorflow.NumPy
var buffer = reader.ReadBytes(bytes * total);
System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length);

return matrix;
}

NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape)
{
//int data = reader.ReadByte();
//Console.WriteLine(data);
//Console.WriteLine(reader.ReadByte());
Stream stream = reader.BaseStream;
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor());
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor());

var unpickler = new Unpickler();
NDArray result = (NDArray) unpickler.load(stream);
Console.WriteLine(result.dims);
return result;
}

public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false)
{
var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse);


src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs (+18, -4)

@@ -27,9 +27,20 @@ namespace Tensorflow.NumPy
Array matrix = Array.CreateInstance(type, shape);

//if (type == typeof(String))
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
NDArray res = ReadObjectMatrix(reader, matrix, shape);
Console.WriteLine("LoadMatrix");
Console.WriteLine(res.dims[0]);
Console.WriteLine((int)res[0][0]);
Console.WriteLine(res.dims[1]);
//if (type == typeof(Object))
//{
//    // TODO: return the unpickled object array `res` directly.
//}
//else
return ReadValueMatrix(reader, matrix, bytes, type, shape);
}

}

public T Load<T>(Stream stream)
@@ -37,7 +48,7 @@ namespace Tensorflow.NumPy
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
// if (typeof(T).IsArray && (typeof(T).GetElementType().IsArray || typeof(T).GetElementType() == typeof(string)))
// return LoadJagged(stream) as T;
return LoadMatrix(stream) as T;
}

@@ -48,7 +59,7 @@ namespace Tensorflow.NumPy
shape = null;

// The first 6 bytes are a magic string: exactly "\x93NUMPY"
if (reader.ReadByte() != 0x93) return false;
if (reader.ReadChar() != 'N') return false;
if (reader.ReadChar() != 'U') return false;
if (reader.ReadChar() != 'M') return false;
@@ -64,6 +75,7 @@ namespace Tensorflow.NumPy
ushort len = reader.ReadUInt16();

string header = new String(reader.ReadChars(len));
Console.WriteLine(header);
string mark = "'descr': '";
int s = header.IndexOf(mark) + mark.Length;
int e = header.IndexOf("'", s + 1);
@@ -93,7 +105,7 @@ namespace Tensorflow.NumPy
Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
{
isLittleEndian = IsLittleEndian(dtype);
bytes = dtype.Length > 2 ? Int32.Parse(dtype.Substring(2)) : 0;

string typeCode = dtype.Substring(1);

@@ -121,6 +133,8 @@ namespace Tensorflow.NumPy
return typeof(Double);
if (typeCode.StartsWith("S"))
return typeof(String);
if (typeCode == "O")
return typeof(Object);

throw new NotSupportedException();
}
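
For reference, the header np.save writes for an object-dtype array typically looks like the line below; the descr is a bare "O" with no trailing byte count, which is the case the Int32.Parse guard above now tolerates:

{'descr': '|O', 'fortran_order': False, 'shape': (25000,), }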


src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs (+44, -0)

@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Razorvine.Pickle;

namespace Tensorflow.NumPy
{
/// <summary>
/// Reconstructs NDArray instances from `numpy.core.multiarray._reconstruct`
/// calls encountered while unpickling object-dtype arrays.
/// </summary>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
public class MultiArrayConstructor : IObjectConstructor
{
public object construct(object[] args)
{
//Console.WriteLine(args.Length);
//for (int i = 0; i < args.Length; i++)
//{
// Console.WriteLine(args[i]);
//}
Console.WriteLine("MultiArrayConstructor");

// numpy calls _reconstruct(subtype, shape, dtype placeholder); args[1] is the shape tuple.
var arg1 = (Object[])args[1];
var dims = new int[arg1.Length];
for (var i = 0; i < arg1.Length; i++)
{
dims[i] = (int)arg1[i];
}

var dtype = TF_DataType.DtInvalid;
switch (args[2])
{
case "b": dtype = TF_DataType.DtUint8Ref; break;
default: throw new NotImplementedException("cannot parse " + args[2]);
}
return new NDArray(new Shape(dims), dtype);

}
}
}
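
A note on the protocol: numpy's ndarray.__reduce__ passes _reconstruct only placeholder arguments, typically (ndarray, (0,), b'b'), and delivers the real shape, dtype and data afterwards through __setstate__. That is why the switch above only ever sees "b" here. An illustrative sequence (typical values, not captured from a real stream):

// _reconstruct(numpy.ndarray, (0,), b'b')  -> empty NDArray shell
// shell.__setstate__((1, (25000,), dtype('O'), False, [elem0, elem1, ...]))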

src/TensorFlowNET.Core/NumPy/NDArray.Pickle.cs (+19, -0)

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.NumPy
{
public partial class NDArray
{
public void __setstate__(object[] args)
{
Console.WriteLine("NDArray __setstate__");
Console.WriteLine(args.Length);
for (int i = 0; i < args.Length; i++)
{
Console.WriteLine(args[i]);
}
}
}
}
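
numpy delivers the real array state here as a 5-tuple (version, shape, dtype, fortran_order, data); for object arrays the data slot is a list of the unpickled elements. A hedged sketch of the next step for this stub (the reshape/data-filling part is hypothetical, not in this commit):

public void __setstate__(object[] args)
{
    // Per numpy's pickle protocol:
    //   args[0] version (int), args[1] shape tuple, args[2] dtype,
    //   args[3] fortran_order (bool), args[4] bytes or object[] data.
    var shapeTuple = (object[])args[1];
    var dims = new long[shapeTuple.Length];
    for (int i = 0; i < dims.Length; i++)
        dims[i] = Convert.ToInt64(shapeTuple[i]);
    // TODO (hypothetical): resize to new Shape(dims) and copy args[4] in.
}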

src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+1, -0)

@@ -112,6 +112,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Razorvine.Pickle" Version="1.4.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>
</Project>

src/TensorFlowNET.Keras/Datasets/Imdb.cs (+59, -4)

@@ -5,6 +5,13 @@ using System.Text;
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using System.Linq;
using Google.Protobuf.Collections;
using Microsoft.VisualBasic;
using OneOf.Types;
using static HDF.PInvoke.H5;
using System.Data;
using System.Reflection.Emit;
using System.Xml.Linq;

namespace Tensorflow.Keras.Datasets
{
@@ -12,13 +19,59 @@ namespace Tensorflow.Keras.Datasets
/// This is a dataset of 25,000 movie reviews from IMDB, labeled by sentiment
/// (positive/negative). Reviews have been preprocessed, and each review is
/// encoded as a list of word indexes (integers).
/// For convenience, words are indexed by overall frequency in the dataset,
/// so that for instance the integer "3" encodes the 3rd most frequent word in
/// the data. This allows for quick filtering operations such as:
/// "only consider the top 10,000 most
/// common words, but eliminate the top 20 most common words".
/// As a convention, "0" does not stand for a specific word, but instead is used
/// to encode the pad token.
/// Args:
///     path: where to cache the data (relative to %TEMP%/imdb/imdb.npz).
///     num_words: integer or None. Words are
///         ranked by how often they occur (in the training set) and only
///         the `num_words` most frequent words are kept. Any less frequent word
///         will appear as `oov_char` value in the sequence data. If None,
///         all words are kept. Defaults to `None`.
///     skip_top: skip the top N most frequently occurring words
///         (which may not be informative). These words will appear as
///         `oov_char` value in the dataset. When 0, no words are
///         skipped. Defaults to `0`.
///     maxlen: int or None. Maximum sequence length.
///         Any longer sequence will be truncated. None means no truncation.
///         Defaults to `None`.
///     seed: int. Seed for reproducible data shuffling.
///     start_char: int. The start of a sequence will be marked with this
///         character. 0 is usually the padding character. Defaults to `1`.
///     oov_char: int. The out-of-vocabulary character.
///         Words that were cut out because of the `num_words` or
///         `skip_top` limits will be replaced with this character.
///     index_from: int. Index actual words with this index and higher.
/// Returns:
///     Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
///
///     **x_train, x_test**: lists of sequences, which are lists of indexes
///     (integers). If the `num_words` argument was specified, the maximum
///     possible index value is `num_words - 1`. If the `maxlen` argument was
///     specified, the largest possible sequence length is `maxlen`.
///
///     **y_train, y_test**: lists of integer labels (1 or 0).
///
/// Raises:
///     ValueError: in case `maxlen` is so low
///     that no input sequence could be kept.
/// Note that the 'out of vocabulary' character is only used for
/// words that were present in the training set but are not included
/// because they're not making the `num_words` cut here.
/// Words that were not seen in the training set but are in the test set
/// have simply been skipped.
/// </summary>
public class Imdb
{
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
string file_name = "imdb.npz";
string dest_folder = "imdb";

/// <summary>
/// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
/// </summary>
@@ -41,8 +94,10 @@ namespace Tensorflow.Keras.Datasets
int index_from = 3)
{
var dst = Download();

var fileBytes = File.ReadAllBytes(Path.Combine(dst, file_name));
var (x_train, x_test) = LoadX(fileBytes);
var (y_train, y_test) = LoadY(fileBytes);
/*var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
var x_train_string = new string[lines.Length];
var y_train = np.zeros(new int[] { lines.Length }, np.int64);
for (int i = 0; i < lines.Length; i++)
@@ -62,7 +117,7 @@ namespace Tensorflow.Keras.Datasets
x_test_string[i] = lines[i].Substring(2);
}

var x_test = np.array(x_test_string);*/

return new DatasetPass
{
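
LoadX and LoadY are referenced above but defined outside this hunk. Since imdb.npz is a zip archive whose entries (x_train.npy, y_train.npy, x_test.npy, y_test.npy) are standalone .npy streams, their shape is presumably along these lines; a hypothetical sketch, where the np.load stream overload and entry names are assumptions:

using System.IO;
using System.IO.Compression;

(NDArray, NDArray) LoadX(byte[] bytes)
{
    // imdb.npz is a zip; each entry is an independent .npy stream.
    using var zip = new ZipArchive(new MemoryStream(bytes));
    using var train = zip.GetEntry("x_train.npy").Open();
    using var test = zip.GetEntry("x_test.npy").Open();
    return (np.load(train), np.load(test)); // stream overload assumed
}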


test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs (+17, -0)

@@ -1,7 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.UnitTest.Dataset
{
@@ -195,5 +197,20 @@ namespace TensorFlowNET.UnitTest.Dataset

Assert.IsFalse(allEqual);
}
[TestMethod]
public void GetData()
{
var vocab_size = 20000; // Only consider the top 20k words
var maxlen = 200; // Only consider the first 200 words of each movie review
var dataset = keras.datasets.imdb.load_data(num_words: vocab_size);
var x_train = dataset.Train.Item1;
var y_train = dataset.Train.Item2;
var x_val = dataset.Test.Item1;
var y_val = dataset.Test.Item2;
print(len(x_train) + " Training sequences");
print(len(x_val) + " Validation sequences");
x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_train, maxlen: maxlen);
x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_val, maxlen: maxlen);
}
}
}
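
GetData currently exercises the pipeline end to end without asserting anything. Once loading stabilizes, checks along these lines could pin down the result (25,000 is the documented size of each IMDB split):

Assert.AreEqual(25000, len(x_train));        // standard IMDB train split
Assert.AreEqual(len(x_train), len(y_train)); // one label per review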
