diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index b96f8203..23c4296c 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{265765E1-C746-4241-AF2B-39B8045292D8}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
+ {265765E1-C746-4241-AF2B-39B8045292D8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {265765E1-C746-4241-AF2B-39B8045292D8}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {265765E1-C746-4241-AF2B-39B8045292D8}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {265765E1-C746-4241-AF2B-39B8045292D8}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 3aa0192b..2aa6ad17 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -62,4 +62,8 @@ Add Word2Vec example.
+ <ItemGroup>
+ <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
+ </ItemGroup>
+
diff --git a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
index bb5d5675..4bc1d84d 100644
--- a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
@@ -6,7 +6,7 @@ using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
-namespace TensorFlowNET.Examples.CnnTextClassification
+namespace TensorFlowNET.Examples
{
public class DataHelpers
{
@@ -90,5 +90,65 @@ namespace TensorFlowNET.Examples.CnnTextClassification
str = Regex.Replace(str, @"\'s", " \'s");
return str;
}
+
+ /// <summary>
+ /// Padding
+ /// </summary>
+ /// <param name="sequences"></param>
+ /// <param name="pad_tok">the char to pad with</param>
+ /// <returns>a list of list where each sublist has same length</returns>
+ public static (int[][], int[]) pad_sequences(int[][] sequences, int pad_tok = 0)
+ {
+ int max_length = sequences.Select(x => x.Length).Max();
+ return _pad_sequences(sequences, pad_tok, max_length);
+ }
+
+ public static (int[][][], int[][]) pad_sequences(int[][][] sequences, int pad_tok = 0)
+ {
+ int max_length_word = sequences.Select(x => x.Select(w => w.Length).Max()).Max();
+ int[][][] sequence_padded;
+ var sequence_length = new int[sequences.Length][];
+ for (int i = 0; i < sequences.Length; i++)
+ {
+ // all words are same length now
+ var (sp, sl) = _pad_sequences(sequences[i], pad_tok, max_length_word);
+ sequence_length[i] = sl;
+ }
+
+ int max_length_sentence = sequences.Select(x => x.Length).Max();
+ (sequence_padded, _) = _pad_sequences(sequences, np.repeat(pad_tok, max_length_word).Data<int>(), max_length_sentence);
+ (sequence_length, _) = _pad_sequences(sequence_length, 0, max_length_sentence);
+
+ return (sequence_padded, sequence_length);
+ }
+
+ private static (int[][], int[]) _pad_sequences(int[][] sequences, int pad_tok, int max_length)
+ {
+ var sequence_length = new int[sequences.Length];
+ for (int i = 0; i < sequences.Length; i++)
+ {
+ sequence_length[i] = sequences[i].Length;
+ Array.Resize(ref sequences[i], max_length);
+ }
+
+ return (sequences, sequence_length);
+ }
+
+ private static (int[][][], int[]) _pad_sequences(int[][][] sequences, int[] pad_tok, int max_length)
+ {
+ var sequence_length = new int[sequences.Length];
+ for (int i = 0; i < sequences.Length; i++)
+ {
+ sequence_length[i] = sequences[i].Length;
+ Array.Resize(ref sequences[i], max_length);
+ for (int j = 0; j < max_length - sequence_length[i]; j++)
+ {
+ sequences[i][max_length - j - 1] = new int[pad_tok.Length];
+ Array.Copy(pad_tok, sequences[i][max_length - j - 1], pad_tok.Length);
+ }
+ }
+
+ return (sequences, sequence_length);
+ }
}
}
diff --git a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
index 71e20b65..f34b132b 100644
--- a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
@@ -1,4 +1,5 @@
-using System;
+using NumSharp;
+using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@@ -7,6 +8,7 @@ using Tensorflow;
using Tensorflow.Estimator;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
+using static TensorFlowNET.Examples.DataHelpers;
namespace TensorFlowNET.Examples.Text.NER
{
@@ -27,10 +29,17 @@ namespace TensorFlowNET.Examples.Text.NER
HyperParams hp;
- Dictionary<string, int> vocab_tags = new Dictionary<string, int>();
int nwords, nchars, ntags;
CoNLLDataset dev, train;
+ Tensor word_ids_tensor;
+ Tensor sequence_lengths_tensor;
+ Tensor char_ids_tensor;
+ Tensor word_lengths_tensor;
+ Tensor labels_tensor;
+ Tensor dropout_tensor;
+ Tensor lr_tensor;
+
public bool Run()
{
PrepareData();
@@ -38,6 +47,14 @@ namespace TensorFlowNET.Examples.Text.NER
tf.train.import_meta_graph("graph/lstm_crf_ner.meta");
+ word_ids_tensor = graph.OperationByName("word_ids");
+ sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
+ char_ids_tensor = graph.OperationByName("char_ids");
+ word_lengths_tensor = graph.OperationByName("word_lengths");
+ labels_tensor = graph.OperationByName("labels");
+ dropout_tensor = graph.OperationByName("dropout");
+ lr_tensor = graph.OperationByName("lr");
+
var init = tf.global_variables_initializer();
with(tf.Session(), sess =>
@@ -47,6 +64,7 @@ namespace TensorFlowNET.Examples.Text.NER
foreach (var epoch in range(hp.epochs))
{
print($"Epoch {epoch + 1} out of {hp.epochs}");
+ run_epoch(train, dev, epoch);
}
});
@@ -54,6 +72,77 @@ namespace TensorFlowNET.Examples.Text.NER
return true;
}
+ private void run_epoch(CoNLLDataset train, CoNLLDataset dev, int epoch)
+ {
+ int i = 0;
+ // iterate over dataset
+ var batches = minibatches(train, hp.batch_size);
+ foreach (var(words, labels) in batches)
+ {
+ get_feed_dict(words, labels, hp.lr, hp.dropout);
+ }
+ }
+
+ private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
+ {
+ var x_batch = new List<(int[][], int[])>();
+ var y_batch = new List<int[]>();
+ foreach(var (x, y) in data.GetItems())
+ {
+ if (len(y_batch) == minibatch_size)
+ {
+ yield return (x_batch.ToArray(), y_batch.ToArray());
+ x_batch.Clear();
+ y_batch.Clear();
+ }
+
+ var x3 = (x.Select(x1 => x1.Item1).ToArray(), x.Select(x2 => x2.Item2).ToArray());
+ x_batch.Add(x3);
+ y_batch.Add(y);
+ }
+
+ if (len(y_batch) > 0)
+ yield return (x_batch.ToArray(), y_batch.ToArray());
+ }
+
+ /// <summary>
+ /// Given some data, pad it and build a feed dictionary
+ /// </summary>
+ /// <param name="words">
+ /// list of sentences. A sentence is a list of ids of a list of
+ /// words. A word is a list of ids
+ /// </param>
+ /// <param name="labels">list of ids</param>
+ /// <param name="lr">learning rate</param>
+ /// <param name="dropout">keep prob</param>
+ private FeedItem[] get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
+ {
+ int[] sequence_lengths;
+ int[][] word_lengths;
+ int[][] word_ids;
+ int[][][] char_ids;
+
+ if (true) // use_chars
+ {
+ (char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray());
+ (word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0);
+ (char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0);
+ }
+
+ // build feed dictionary
+ var feeds = new List<FeedItem>();
+ feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids)));
+ feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths)));
+
+ if(true) // use_chars
+ {
+ feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids)));
+ feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
+ }
+
+ throw new NotImplementedException("get_feed_dict");
+ }
+
public void PrepareData()
{
hp = new HyperParams("LstmCrfNer")
diff --git a/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs b/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs
index 8fc7b25a..9b50bfd6 100644
--- a/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs
+++ b/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs
@@ -8,13 +8,14 @@ using Tensorflow.Estimator;
namespace TensorFlowNET.Examples.Utility
{
- public class CoNLLDataset : IEnumerable
+ public class CoNLLDataset
{
static Dictionary<string, int> vocab_chars;
static Dictionary<string, int> vocab_words;
+ static Dictionary<string, int> vocab_tags;
- List<List<string>> _elements;
HyperParams _hp;
+ string _path;
public CoNLLDataset(string path, HyperParams hp)
{
@@ -24,22 +25,10 @@ namespace TensorFlowNET.Examples.Utility
if (vocab_words == null)
vocab_words = load_vocab(hp.filepath_words);
- var lines = File.ReadAllLines(path);
+ if (vocab_tags == null)
+ vocab_tags = load_vocab(hp.filepath_tags);
- foreach (var l in lines)
- {
- string line = l.Trim();
- if (string.IsNullOrEmpty(line) || line.StartsWith("-DOCSTART-"))
- {
-
- }
- else
- {
- var ls = line.Split(' ');
- // process word
- var word = processing_word(ls[0]);
- }
- }
+ _path = path;
}
private (int[], int) processing_word(string word)
@@ -58,6 +47,20 @@ namespace TensorFlowNET.Examples.Utility
return (char_ids, id);
}
+ private int processing_tag(string word)
+ {
+ // 1. preprocess word
+ if (false) // lowercase
+ word = word.ToLower();
+ if (false) // isdigit
+ word = "$NUM$";
+
+ // 2. get id of word
+ int id = vocab_tags.GetValueOrDefault(word, -1);
+
+ return id;
+ }
+
private Dictionary<string, int> load_vocab(string filename)
{
var dict = new Dictionary<string, int>();
@@ -68,9 +71,38 @@ namespace TensorFlowNET.Examples.Utility
return dict;
}
- public IEnumerator GetEnumerator()
+ public IEnumerable<((int[], int)[], int[])> GetItems()
{
- return _elements.GetEnumerator();
+ var lines = File.ReadAllLines(_path);
+
+ int niter = 0;
+ var words = new List<(int[], int)>();
+ var tags = new List<int>();
+
+ foreach (var l in lines)
+ {
+ string line = l.Trim();
+ if (string.IsNullOrEmpty(line) || line.StartsWith("-DOCSTART-"))
+ {
+ if (words.Count > 0)
+ {
+ niter++;
+ yield return (words.ToArray(), tags.ToArray());
+ words.Clear();
+ tags.Clear();
+ }
+ }
+ else
+ {
+ var ls = line.Split(' ');
+ // process word
+ var word = processing_word(ls[0]);
+ var tag = processing_tag(ls[1]);
+
+ words.Add(word);
+ tags.Add(tag);
+ }
+ }
}
}
}