@@ -6,5 +6,8 @@ namespace Tensorflow.Keras
    {
        public Sequence sequence => new Sequence();
        public DatasetUtils dataset_utils => new DatasetUtils();
        public TextApi text => _text;

        private static TextApi _text = new TextApi();
    }
}
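
This hunk exposes the new text preprocessing entry point as a `text` property backed by a single shared `TextApi` instance. A minimal sketch of the resulting call path (mirroring the unit tests further down; `keras` comes from `using static Tensorflow.KerasApi;`):

// Illustrative one-liner showing how the new accessor is reached.
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: "<OOV>");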
@@ -0,0 +1,176 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
namespace Tensorflow.Keras.Text
{
    /// <summary>
    /// Text tokenization API that turns raw strings into sequences of integer word ids,
    /// modeled after tf.keras.preprocessing.text.Tokenizer.
    /// </summary>
    public class Tokenizer
    {
        private readonly int num_words;
        private readonly string filters;
        private readonly bool lower;
        private readonly char split;
        private readonly bool char_level;
        private readonly string oov_token;
        private readonly Func<string, IEnumerable<string>> analyzer;

        private int document_count = 0;

        private Dictionary<string, int> word_docs = new Dictionary<string, int>();
        private Dictionary<string, int> word_counts = new Dictionary<string, int>();

        public Dictionary<string, int> word_index = null;
        public Dictionary<int, string> index_word = null;

        private Dictionary<int, int> index_docs = null;

        public Tokenizer(
            int num_words = -1,
            string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
            bool lower = true,
            char split = ' ',
            bool char_level = false,
            string oov_token = null,
            Func<string, IEnumerable<string>> analyzer = null)
        {
            this.num_words = num_words;
            this.filters = filters;
            this.lower = lower;
            this.split = split;
            this.char_level = char_level;
            this.oov_token = oov_token;
            this.analyzer = analyzer;
        }
        /// <summary>
        /// Updates the internal vocabulary based on a list of texts.
        /// </summary>
        public void fit_on_texts(IEnumerable<string> texts)
        {
            foreach (var text in texts)
            {
                IEnumerable<string> seq = null;

                document_count += 1;
                if (char_level)
                {
                    throw new NotImplementedException("char_level == true");
                }
                else
                {
                    seq = analyzer(lower ? text.ToLower() : text);
                }

                foreach (var w in seq)
                {
                    var count = 0;
                    word_counts.TryGetValue(w, out count);
                    word_counts[w] = count + 1;
                }

                foreach (var w in new HashSet<string>(seq))
                {
                    var count = 0;
                    word_docs.TryGetValue(w, out count);
                    word_docs[w] = count + 1;
                }
            }

            // Most frequent words get the lowest indices; the OOV token, if any, always gets index 1.
            var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
            sorted_voc.AddRange(word_counts.OrderByDescending(kv => kv.Value).Select(kv => kv.Key));

            if (num_words > -1)
            {
                sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList();
            }

            word_index = new Dictionary<string, int>(sorted_voc.Count);
            index_word = new Dictionary<int, string>(sorted_voc.Count);
            index_docs = new Dictionary<int, int>(word_docs.Count);

            // Index 0 is reserved, so word indices start at 1.
            for (int i = 0; i < sorted_voc.Count; i++)
            {
                word_index.Add(sorted_voc[i], i + 1);
                index_word.Add(i + 1, sorted_voc[i]);
            }

            foreach (var kv in word_docs)
            {
                var idx = -1;
                if (word_index.TryGetValue(kv.Key, out idx))
                {
                    index_docs.Add(idx, kv.Value);
                }
            }
        }
        public void fit_on_sequences(IEnumerable<int[]> sequences)
        {
            throw new NotImplementedException("fit_on_sequences");
        }

        /// <summary>
        /// Transforms each text in texts to a sequence of integer word ids.
        /// </summary>
        public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
        {
            return texts_to_sequences_generator(texts).ToArray();
        }

        public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
        {
            int oov_index = -1;
            if (oov_token != null)
                word_index.TryGetValue(oov_token, out oov_index);

            return texts.Select(text =>
            {
                IEnumerable<string> seq = null;

                if (char_level)
                {
                    throw new NotImplementedException("char_level == true");
                }
                else
                {
                    seq = analyzer(lower ? text.ToLower() : text);
                }

                var vect = new List<int>();
                foreach (var w in seq)
                {
                    var i = -1;
                    if (word_index.TryGetValue(w, out i))
                    {
                        if (num_words != -1 && i >= num_words)
                        {
                            // Known word beyond the num_words cutoff: map to OOV if available.
                            if (oov_index != -1)
                                vect.Add(oov_index);
                        }
                        else
                        {
                            vect.Add(i);
                        }
                    }
                    else if (oov_index != -1)
                    {
                        // Unknown word: map to OOV if available, otherwise drop it.
                        vect.Add(oov_index);
                    }
                }

                return vect.ToArray();
            });
        }

        public IEnumerable<string> sequences_to_texts(IEnumerable<int[]> sequences)
        {
            throw new NotImplementedException("sequences_to_texts");
        }

        public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
        {
            throw new NotImplementedException("sequences_to_matrix");
        }
    }
}
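
A minimal usage sketch of the Tokenizer above, with an explicit whitespace analyzer (the sample strings and the lambda are illustrative, not part of the change):

// Fit a vocabulary, then encode new text as integer word ids.
var tokenizer = new Tensorflow.Keras.Text.Tokenizer(
    oov_token: "<OOV>",
    analyzer: t => t.Split(' '));
tokenizer.fit_on_texts(new[] { "the quick brown fox", "the lazy dog" });
// Ids follow descending frequency: "the" gets the id right after the OOV token,
// and the unknown word "cat" maps to the OOV id.
var seqs = tokenizer.texts_to_sequences(new[] { "the quick cat" });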
@@ -15,7 +15,9 @@
******************************************************************************/

using NumSharp;
using NumSharp.Utilities;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras
@@ -34,14 +36,18 @@ namespace Tensorflow.Keras
        /// <param name="truncating">String, 'pre' or 'post'</param>
        /// <param name="value">Float or String, padding value.</param>
        /// <returns></returns>
-       public NDArray pad_sequences(NDArray sequences,
+       public NDArray pad_sequences(IEnumerable<int[]> sequences,
            int? maxlen = null,
            string dtype = "int32",
            string padding = "pre",
            string truncating = "pre",
            object value = null)
        {
-           int[] length = new int[sequences.size];
+           if (value != null) throw new NotImplementedException("padding with a specific value.");
+           if (padding != "pre" && padding != "post") throw new InvalidArgumentError("padding must be 'pre' or 'post'.");
+           if (truncating != "pre" && truncating != "post") throw new InvalidArgumentError("truncating must be 'pre' or 'post'.");
+
+           var length = sequences.Select(s => s.Length);

            if (maxlen == null)
                maxlen = length.Max();
@@ -49,19 +55,28 @@ namespace Tensorflow.Keras
            if (value == null)
                value = 0f;

-           var nd = new NDArray(np.int32, new Shape(sequences.size, maxlen.Value));
+           var type = getNPType(dtype);
+           var nd = new NDArray(type, new Shape(length.Count(), maxlen.Value), true);

            for (int i = 0; i < nd.shape[0]; i++)
            {
-               switch (sequences[i])
+               var s = sequences.ElementAt(i);
+               if (s.Length > maxlen.Value)
                {
-                   default:
-                       throw new NotImplementedException("pad_sequences");
+                   // Truncate: keep the tail for "pre", the head for "post".
+                   s = (truncating == "pre") ? s.Slice(s.Length - maxlen.Value, s.Length) : s.Slice(0, maxlen.Value);
                }
+
+               // "pre" writes the sequence into the trailing columns, "post" into the leading ones.
+               var sliceString = (padding == "pre")
+                   ? $"{i},{maxlen - s.Length}:"
+                   : $"{i},:{s.Length}";
+               nd[sliceString] = np.array(s);
            }

            return nd;
        }
+
+       private Type getNPType(string typeName)
+       {
+           // Resolve a dtype name such as "int32" to the matching NumSharp field (np.int32).
+           return typeof(np).GetField(typeName).GetValue(null) as Type;
+       }
    }
}
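
The slice strings built above drive NumSharp's string indexer: "pre" writes each (possibly truncated) sequence into the trailing columns of its zero-filled row, "post" into the leading columns. A short sketch of the expected behavior under those assumptions; the literal rows are illustrative:

// maxlen defaults to the longest row (3 here).
var padded = keras.preprocessing.sequence.pad_sequences(
    new[] { new[] { 1, 2, 3 }, new[] { 4, 5 } });
// Row 0 is assigned via "0,0:" -> [1, 2, 3]; row 1 via "1,1:" -> [0, 4, 5].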
@@ -0,0 +1,42 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Text;

namespace Tensorflow.Keras
{
    public class TextApi
    {
        private const string DefaultFilter = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n";

        public Tensorflow.Keras.Text.Tokenizer Tokenizer(
            int num_words = -1,
            string filters = DefaultFilter,
            bool lower = true,
            char split = ' ',
            bool char_level = false,
            string oov_token = null,
            Func<string, IEnumerable<string>> analyzer = null)
        {
            // Default analyzer: split into words using the same filters the tokenizer was given.
            analyzer = analyzer ?? (text => text_to_word_sequence(text, filters, lower, split));
            return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, analyzer);
        }

        public static IEnumerable<string> text_to_word_sequence(string text, string filters = DefaultFilter, bool lower = true, char split = ' ')
        {
            if (lower)
            {
                text = text.ToLower();
            }

            var newText = new string(text.Where(c => !filters.Contains(c)).ToArray());
            return newText.Split(new[] { split }, StringSplitOptions.RemoveEmptyEntries);
        }
    }
}
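
A quick sketch of what the default analyzer produces (the sample sentence is illustrative):

// The default filter strips punctuation, then the text is split on the separator.
var words = TextApi.text_to_word_sequence("It was the best of times!");
// -> "it", "was", "the", "best", "of", "times"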
@@ -0,0 +1,120 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
{
{ | |||
[TestClass] | |||
public class PreprocessingTests : EagerModeTestBase | |||
{ | |||
private readonly string[] texts = new string[] { | |||
"It was the best of times, it was the worst of times.", | |||
"this is a new dawn, an era to follow the previous era. It can not be said to start anew.", | |||
"It was the best of times, it was the worst of times.", | |||
"this is a new dawn, an era to follow the previous era.", | |||
}; | |||
private const string OOV = "<OOV>"; | |||
[TestMethod] | |||
public void TokenizeWithNoOOV() | |||
{ | |||
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true); | |||
tokenizer.fit_on_texts(texts); | |||
Assert.AreEqual(23, tokenizer.word_index.Count); | |||
Assert.AreEqual(7, tokenizer.word_index["worst"]); | |||
Assert.AreEqual(12, tokenizer.word_index["dawn"]); | |||
Assert.AreEqual(16, tokenizer.word_index["follow"]); | |||
} | |||
[TestMethod] | |||
public void TokenizeWithOOV() | |||
{ | |||
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV); | |||
tokenizer.fit_on_texts(texts); | |||
Assert.AreEqual(24, tokenizer.word_index.Count); | |||
Assert.AreEqual(1, tokenizer.word_index[OOV]); | |||
Assert.AreEqual(8, tokenizer.word_index["worst"]); | |||
Assert.AreEqual(13, tokenizer.word_index["dawn"]); | |||
Assert.AreEqual(17, tokenizer.word_index["follow"]); | |||
} | |||
        [TestMethod]
        public void PadSequencesWithDefaults()
        {
            var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
            tokenizer.fit_on_texts(texts);
            var sequences = tokenizer.texts_to_sequences(texts);
            var padded = keras.preprocessing.sequence.pad_sequences(sequences);

            Assert.AreEqual(4, padded.shape[0]);
            Assert.AreEqual(20, padded.shape[1]);

            // Row 0 (12 tokens) is left-padded with zeros up to the default maxlen of 20.
            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
            for (var i = 0; i < 8; i++)
                Assert.AreEqual(0, padded[0, i].GetInt32());

            // Row 1 (20 tokens) fills its row exactly, so no padding zeros remain.
            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
            for (var i = 0; i < 20; i++)
                Assert.AreNotEqual(0, padded[1, i].GetInt32());
        }

        [TestMethod]
        public void PadSequencesPrePaddingTrunc()
        {
            var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
            tokenizer.fit_on_texts(texts);
            var sequences = tokenizer.texts_to_sequences(texts);
            var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15);

            Assert.AreEqual(4, padded.shape[0]);
            Assert.AreEqual(15, padded.shape[1]);

            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
            for (var i = 0; i < 3; i++)
                Assert.AreEqual(0, padded[0, i].GetInt32());

            // Row 1 (20 tokens) is truncated from the front down to 15.
            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
            for (var i = 0; i < 15; i++)
                Assert.AreNotEqual(0, padded[1, i].GetInt32());
        }

        [TestMethod]
        public void PadSequencesPostPaddingTrunc()
        {
            var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
            tokenizer.fit_on_texts(texts);
            var sequences = tokenizer.texts_to_sequences(texts);
            var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post");

            Assert.AreEqual(4, padded.shape[0]);
            Assert.AreEqual(15, padded.shape[1]);

            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
            for (var i = 12; i < 15; i++)
                Assert.AreEqual(0, padded[0, i].GetInt32());

            // Row 1 (20 tokens) is truncated from the back down to 15.
            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
            for (var i = 0; i < 15; i++)
                Assert.AreNotEqual(0, padded[1, i].GetInt32());
        }
    }
}
@@ -6,6 +6,8 @@
    <IsPackable>false</IsPackable>
    <Platforms>AnyCPU;x64</Platforms>
    <RootNamespace>TensorflowNET.Keras.UnitTest</RootNamespace>
  </PropertyGroup>

  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">