@@ -6,5 +6,8 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
public Sequence sequence => new Sequence(); | public Sequence sequence => new Sequence(); | ||||
public DatasetUtils dataset_utils => new DatasetUtils(); | public DatasetUtils dataset_utils => new DatasetUtils(); | ||||
public TextApi text => _text; | |||||
private static TextApi _text = new TextApi(); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,176 @@ | |||||
using NumSharp; | |||||
using Serilog.Debugging; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Collections.Specialized; | |||||
using System.Linq; | |||||
using System.Net.Sockets; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Text | |||||
{ | |||||
public class Tokenizer
{
    /// <summary>
    /// Text tokenization utility, modeled on keras.preprocessing.text.Tokenizer.
    /// Build a frequency-ranked vocabulary with <see cref="fit_on_texts"/>, then
    /// convert texts to integer sequences with <see cref="texts_to_sequences"/>.
    /// </summary>
    private readonly int num_words;     // cap on vocabulary used by texts_to_sequences; -1 = no cap
    private readonly string filters;    // characters stripped by the default analyzer (kept for API parity)
    private readonly bool lower;        // lowercase text before analysis
    private readonly char split;        // word separator used by the default analyzer
    private readonly bool char_level;   // per-character tokenization (not implemented)
    private readonly string oov_token;  // token substituted for out-of-vocabulary words; null = drop them
    private readonly Func<string, IEnumerable<string>> analyzer; // text -> token stream

    private int document_count = 0;
    private Dictionary<string, int> word_docs = new Dictionary<string, int>();    // word -> number of documents containing it
    private Dictionary<string, int> word_counts = new Dictionary<string, int>();  // word -> total occurrence count

    public Dictionary<string, int> word_index = null;  // word -> 1-based rank (OOV token, when set, is rank 1; 0 is reserved for padding)
    public Dictionary<int, string> index_word = null;  // rank -> word (inverse of word_index)
    private Dictionary<int, int> index_docs = null;    // rank -> document frequency

    public Tokenizer(
        int num_words = -1,
        string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
        bool lower = true,
        char split = ' ',
        bool char_level = false,
        string oov_token = null,
        Func<string, IEnumerable<string>> analyzer = null)
    {
        this.num_words = num_words;
        this.filters = filters;
        this.lower = lower;
        this.split = split;
        this.char_level = char_level;
        this.oov_token = oov_token;
        this.analyzer = analyzer;
    }

    /// <summary>
    /// Updates word/document frequency statistics from the given texts and
    /// rebuilds word_index, index_word and index_docs.
    /// </summary>
    public void fit_on_texts(IEnumerable<string> texts)
    {
        foreach (var text in texts)
        {
            document_count += 1;
            if (char_level)
                throw new NotImplementedException("char_level == true");

            // Materialize once: the token stream is consumed twice below and
            // the analyzer may be a lazy generator.
            var seq = analyzer(lower ? text.ToLower() : text).ToList();

            foreach (var w in seq)
            {
                word_counts.TryGetValue(w, out var count);
                word_counts[w] = count + 1;
            }

            // Distinct words only, for document-frequency counting.
            foreach (var w in new HashSet<string>(seq))
            {
                word_docs.TryGetValue(w, out var count);
                word_docs[w] = count + 1;
            }
        }

        // Rank the vocabulary by descending total frequency. OrderByDescending
        // is a stable sort, so ties keep first-seen order (matching Keras).
        // BUG FIX: the original sorted a copy of the counts but then filled
        // sorted_voc from the *unsorted* dictionary, so ranks ignored frequency.
        var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
        sorted_voc.AddRange(word_counts.OrderByDescending(kv => kv.Value).Select(kv => kv.Key));

        if (num_words > -1)
        {
            // Reserve one extra slot when the OOV token occupies rank 1.
            sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList();
        }

        word_index = new Dictionary<string, int>(sorted_voc.Count);
        index_word = new Dictionary<int, string>(sorted_voc.Count);
        index_docs = new Dictionary<int, int>(word_docs.Count);

        for (int i = 0; i < sorted_voc.Count; i++)
        {
            word_index.Add(sorted_voc[i], i + 1);   // ranks are 1-based; 0 is the padding value
            index_word.Add(i + 1, sorted_voc[i]);
        }

        foreach (var kv in word_docs)
        {
            // Words dropped by the num_words cap have no rank and are skipped.
            if (word_index.TryGetValue(kv.Key, out var idx))
                index_docs.Add(idx, kv.Value);
        }
    }

    public void fit_on_sequences(IEnumerable<int[]> sequences)
    {
        throw new NotImplementedException("fit_on_sequences");
    }

    /// <summary>
    /// Converts each text to an array of word ranks (eager counterpart of
    /// <see cref="texts_to_sequences_generator"/>).
    /// </summary>
    public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
    {
        return texts_to_sequences_generator(texts).ToArray();
    }

    /// <summary>
    /// Lazily converts each text to an array of word ranks. Words ranked at or
    /// above num_words, and unknown words, are mapped to the OOV token's rank
    /// when one is configured, otherwise dropped.
    /// </summary>
    public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
    {
        int oov_index = -1;
        var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);

        return texts.Select(text =>
        {
            if (char_level)
                throw new NotImplementedException("char_level == true");
            var seq = analyzer(lower ? text.ToLower() : text);

            var vect = new List<int>();
            foreach (var w in seq)
            {
                if (word_index.TryGetValue(w, out var i))
                {
                    if (num_words != -1 && i >= num_words)
                    {
                        // Rank exceeds the cap: substitute OOV if available.
                        if (oov_index != -1)
                            vect.Add(oov_index);
                    }
                    else
                    {
                        vect.Add(i);
                    }
                }
                else if (oov_index != -1)
                {
                    // BUG FIX: previously -1 was emitted for unknown words when
                    // no OOV token was configured; Keras drops such words.
                    vect.Add(oov_index);
                }
            }
            return vect.ToArray();
        });
    }

    public IEnumerable<string> sequences_to_texts(IEnumerable<int[]> sequences)
    {
        throw new NotImplementedException("sequences_to_texts");
    }

    public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
    {
        throw new NotImplementedException("sequences_to_matrix");
    }
}
} |
@@ -15,7 +15,9 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using NumSharp; | using NumSharp; | ||||
using NumSharp.Utilities; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
@@ -34,14 +36,18 @@ namespace Tensorflow.Keras | |||||
/// <summary>
/// Pads each integer sequence to the same length, producing a 2-D NDArray.
/// Rows shorter than maxlen are zero-padded; longer rows are truncated.
/// </summary>
/// <param name="sequences">The integer sequences to pad.</param>
/// <param name="maxlen">Target row length; defaults to the longest sequence.</param>
/// <param name="dtype">Name of the NumSharp dtype of the result (e.g. "int32").</param>
/// <param name="padding">String, 'pre' or 'post'</param>
/// <param name="truncating">String, 'pre' or 'post'</param>
/// <param name="value">Float or String, padding value (only the default 0 is supported).</param>
/// <returns>NDArray of shape (sequence count, maxlen).</returns>
public NDArray pad_sequences(IEnumerable<int[]> sequences,
    int? maxlen = null,
    string dtype = "int32",
    string padding = "pre",
    string truncating = "pre",
    object value = null)
{
    if (value != null) throw new NotImplementedException("padding with a specific value.");
    if (padding != "pre" && padding != "post") throw new InvalidArgumentError("padding must be 'pre' or 'post'.");
    if (truncating != "pre" && truncating != "post") throw new InvalidArgumentError("truncating must be 'pre' or 'post'.");

    // Materialize once: the original enumerated `sequences` repeatedly
    // (Select/Max, Count(), and ElementAt(i) inside the loop), which re-runs
    // lazy generators and is O(n^2) for non-indexed enumerables.
    var seqs = sequences.ToArray();

    if (maxlen == null)
        maxlen = seqs.Max(s => s.Length);

    var type = getNPType(dtype);
    // Zero-filled output; each row is copied into its padded window below.
    var nd = new NDArray(type, new Shape(seqs.Length, maxlen.Value), true);

    for (int i = 0; i < seqs.Length; i++)
    {
        var s = seqs[i];
        if (s.Length > maxlen.Value)
        {
            // Too long: keep the tail ("pre") or the head ("post").
            s = (truncating == "pre") ? s.Slice(s.Length - maxlen.Value, s.Length) : s.Slice(0, maxlen.Value);
        }
        // Place the (possibly truncated) row at the right edge ("pre" padding)
        // or the left edge ("post" padding) of the output row.
        var sliceString = (padding == "pre")
            ? (i.ToString() + "," + (maxlen - s.Length).ToString() + ":")
            : (i.ToString() + ",:" + s.Length);
        nd[sliceString] = np.array(s);
    }
    return nd;
}
/// <summary>
/// Resolves a dtype name (e.g. "int32") to its System.Type by reading the
/// matching public static field on NumSharp's np class via reflection.
/// NOTE(review): assumes the NumSharp.Lite assembly is loaded and exposes a
/// static field named exactly <paramref name="typeName"/> — verify at runtime.
/// </summary>
private Type getNPType(string typeName)
{
    var npClass = System.Type.GetType("NumSharp.np,NumSharp.Lite");
    var dtypeField = npClass.GetField(typeName);
    return dtypeField.GetValue(null) as Type;
}
} | } | ||||
} | } |
@@ -0,0 +1,42 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Text; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public class TextApi
{
    /// <summary>
    /// Factory for <see cref="Tensorflow.Keras.Text.Tokenizer"/>. When no
    /// analyzer is supplied, a default one based on
    /// <see cref="text_to_word_sequence"/> (with the same filters/lower/split
    /// settings) is used.
    /// </summary>
    public Tensorflow.Keras.Text.Tokenizer Tokenizer(
        int num_words = -1,
        string filters = DefaultFilter,
        bool lower = true,
        char split = ' ',
        bool char_level = false,
        string oov_token = null,
        Func<string, IEnumerable<string>> analyzer = null)
    {
        var effectiveAnalyzer = analyzer
            ?? (text => text_to_word_sequence(text, filters, lower, split));
        return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, effectiveAnalyzer);
    }

    /// <summary>
    /// Splits a text into words: optionally lowercases it, removes every
    /// character found in <paramref name="filters"/>, then splits on
    /// <paramref name="split"/>.
    /// </summary>
    public static IEnumerable<string> text_to_word_sequence(string text, string filters = DefaultFilter, bool lower = true, char split = ' ')
    {
        var source = lower ? text.ToLower() : text;
        var kept = source.Where(c => !filters.Contains(c)).ToArray();
        return new String(kept).Split(split);
    }

    // Default punctuation/whitespace stripped before splitting (mirrors Keras).
    private const string DefaultFilter = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n";
}
} |
@@ -0,0 +1,120 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Linq; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
using static Tensorflow.KerasApi; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Datasets; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
{ | |||||
[TestClass] | |||||
[TestClass]
public class PreprocessingTests : EagerModeTestBase
{
    // Four documents; texts[0]/texts[2] are exact duplicates and texts[3] is a
    // prefix of texts[1], so word and document frequencies differ per word.
    private readonly string[] texts = new string[] {
        "It was the best of times, it was the worst of times.",
        "this is a new dawn, an era to follow the previous era. It can not be said to start anew.",
        "It was the best of times, it was the worst of times.",
        "this is a new dawn, an era to follow the previous era.",
    };

    private const string OOV = "<OOV>";

    [TestMethod]
    public void TokenizeWithNoOOV()
    {
        var tokenizer = keras.preprocessing.text.Tokenizer(lower: true);
        tokenizer.fit_on_texts(texts);

        Assert.AreEqual(23, tokenizer.word_index.Count);
        // Expected ranks assume a frequency-ranked vocabulary (descending total
        // count, ties in first-seen order), matching Keras' Tokenizer.
        Assert.AreEqual(9, tokenizer.word_index["worst"]);
        Assert.AreEqual(14, tokenizer.word_index["dawn"]);
        Assert.AreEqual(16, tokenizer.word_index["follow"]);
    }

    [TestMethod]
    public void TokenizeWithOOV()
    {
        var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
        tokenizer.fit_on_texts(texts);

        // The OOV token always takes rank 1, shifting every word up by one.
        Assert.AreEqual(24, tokenizer.word_index.Count);
        Assert.AreEqual(1, tokenizer.word_index[OOV]);
        Assert.AreEqual(10, tokenizer.word_index["worst"]);
        Assert.AreEqual(15, tokenizer.word_index["dawn"]);
        Assert.AreEqual(17, tokenizer.word_index["follow"]);
    }

    [TestMethod]
    public void PadSequencesWithDefaults()
    {
        var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
        tokenizer.fit_on_texts(texts);
        var sequences = tokenizer.texts_to_sequences(texts);

        var padded = keras.preprocessing.sequence.pad_sequences(sequences);

        // Default maxlen is the longest sequence (texts[1], 20 words).
        Assert.AreEqual(4, padded.shape[0]);
        Assert.AreEqual(20, padded.shape[1]);
        // Row 0 has 12 tokens, pre-padded: zeros in [0..7], "worst" (token 9) at 17.
        Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
        for (var i = 0; i < 8; i++)
            Assert.AreEqual(0, padded[0, i].GetInt32());
        // Row 1 fills the full width: "previous" at 10 and no padding anywhere.
        Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
        for (var i = 0; i < 20; i++)
            Assert.AreNotEqual(0, padded[1, i].GetInt32());
    }

    [TestMethod]
    public void PadSequencesPrePaddingTrunc()
    {
        var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
        tokenizer.fit_on_texts(texts);
        var sequences = tokenizer.texts_to_sequences(texts);

        var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15);

        Assert.AreEqual(4, padded.shape[0]);
        Assert.AreEqual(15, padded.shape[1]);
        // Row 0 (12 tokens) pre-padded with 3 zeros; "worst" shifts to column 12.
        Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
        for (var i = 0; i < 3; i++)
            Assert.AreEqual(0, padded[0, i].GetInt32());
        // Row 1 (20 tokens) pre-truncated to its last 15: "previous" at column 5.
        Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
        for (var i = 0; i < 15; i++)
            Assert.AreNotEqual(0, padded[1, i].GetInt32());
    }

    [TestMethod]
    public void PadSequencesPostPaddingTrunc()
    {
        var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
        tokenizer.fit_on_texts(texts);
        var sequences = tokenizer.texts_to_sequences(texts);

        var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post");

        Assert.AreEqual(4, padded.shape[0]);
        Assert.AreEqual(15, padded.shape[1]);
        // Row 0 post-padded: tokens in [0..11], zeros in [12..14], "worst" at 9.
        Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
        for (var i = 12; i < 15; i++)
            Assert.AreEqual(0, padded[0, i].GetInt32());
        // Row 1 post-truncated to its first 15 tokens: "previous" at 10, no zeros.
        Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
        for (var i = 0; i < 15; i++)
            Assert.AreNotEqual(0, padded[1, i].GetInt32());
    }
}
} |
@@ -6,6 +6,8 @@ | |||||
<IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
<RootNamespace>TensorflowNET.Keras.UnitTest</RootNamespace> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||