Browse Source

Completed the CoNLLDataset load for the LSTM + CRF NER example.

tags/v0.9
Oceania2018 6 years ago
parent
commit
664e1c779b
5 changed files with 213 additions and 22 deletions
  1. +6
    -0
      TensorFlow.NET.sln
  2. +4
    -0
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  3. +61
    -1
      test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
  4. +91
    -2
      test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
  5. +51
    -19
      test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs

+ 6
- 0
TensorFlow.NET.sln View File

@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{265765E1-C746-4241-AF2B-39B8045292D8}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
{265765E1-C746-4241-AF2B-39B8045292D8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{265765E1-C746-4241-AF2B-39B8045292D8}.Debug|Any CPU.Build.0 = Debug|Any CPU
{265765E1-C746-4241-AF2B-39B8045292D8}.Release|Any CPU.ActiveCfg = Release|Any CPU
{265765E1-C746-4241-AF2B-39B8045292D8}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 4
- 0
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -62,4 +62,8 @@ Add Word2Vec example.</PackageReleaseNotes>
<Folder Include="Keras\Initializers\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>

</Project>

+ 61
- 1
test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs View File

@@ -6,7 +6,7 @@ using System.Linq;
using System.Text;
using System.Text.RegularExpressions;

namespace TensorFlowNET.Examples.CnnTextClassification
namespace TensorFlowNET.Examples
{
public class DataHelpers
{
@@ -90,5 +90,65 @@ namespace TensorFlowNET.Examples.CnnTextClassification
str = Regex.Replace(str, @"\'s", " \'s");
return str;
}

/// <summary>
/// Padding
/// </summary>
/// <param name="sequences"></param>
/// <param name="pad_tok">the char to pad with</param>
/// <returns>a list of list where each sublist has same length</returns>
/// <summary>
/// Pads every sequence up to the length of the longest one.
/// </summary>
/// <param name="sequences">jagged array of id sequences; padded in place</param>
/// <param name="pad_tok">the id to pad with</param>
/// <returns>the padded sequences plus each sequence's original length</returns>
public static (int[][], int[]) pad_sequences(int[][] sequences, int pad_tok = 0)
{
    // Find the longest sequence; that becomes the uniform target length.
    var longest = 0;
    foreach (var seq in sequences)
        longest = Math.Max(longest, seq.Length);

    return _pad_sequences(sequences, pad_tok, longest);
}

/// <summary>
/// Pads nested char-id sequences on two levels: first every word (inner int[])
/// is padded to the longest word length, then every sentence is padded to the
/// longest sentence length with all-pad words.
/// </summary>
/// <param name="sequences">sentences -> words -> char ids; mutated in place</param>
/// <param name="pad_tok">the id to pad with</param>
/// <returns>padded sequences plus, per sentence, each word's original length (also padded)</returns>
public static (int[][][], int[][]) pad_sequences(int[][][] sequences, int pad_tok = 0)
{
// Longest word (in chars) across all sentences.
int max_length_word = sequences.Select(x => x.Select(w => w.Length).Max()).Max();
int[][][] sequence_padded;
var sequence_length = new int[sequences.Length][];
for (int i = 0; i < sequences.Length; i++)
{
// all words are same length now
// sp is discarded on purpose: _pad_sequences resizes sequences[i]'s
// inner arrays in place (Array.Resize on the jagged elements), so the
// mutation is already visible here; only the lengths are kept.
var (sp, sl) = _pad_sequences(sequences[i], pad_tok, max_length_word);
sequence_length[i] = sl;
}

int max_length_sentence = sequences.Select(x => x.Length).Max();
// Pad sentences with whole pad-words: pad_tok repeated max_length_word times.
(sequence_padded, _) = _pad_sequences(sequences, np.repeat(pad_tok, max_length_word).Data<int>(), max_length_sentence);
// Pad the per-word length rows to the same sentence length (0 for pad words).
(sequence_length, _) = _pad_sequences(sequence_length, 0, max_length_sentence);

return (sequence_padded, sequence_length);
}

/// <summary>
/// Pads (or truncates) every inner array in place to exactly
/// <paramref name="max_length"/> items.
/// </summary>
/// <param name="sequences">jagged array; inner arrays are resized in place</param>
/// <param name="pad_tok">value written into the newly added slots</param>
/// <param name="max_length">target length for every inner array</param>
/// <returns>the (mutated) sequences plus each sequence's original length</returns>
private static (int[][], int[]) _pad_sequences(int[][] sequences, int pad_tok, int max_length)
{
    var sequence_length = new int[sequences.Length];
    for (int i = 0; i < sequences.Length; i++)
    {
        sequence_length[i] = sequences[i].Length;
        Array.Resize(ref sequences[i], max_length);

        // BUG FIX: pad_tok used to be ignored — Array.Resize zero-fills the
        // new tail, which only happened to be correct for pad_tok == 0.
        // Write the requested pad token explicitly so non-zero tokens work.
        for (int j = sequence_length[i]; j < max_length; j++)
            sequences[i][j] = pad_tok;
    }

    return (sequences, sequence_length);
}

/// <summary>
/// Pads every sentence (a jagged array of word-id arrays) in place to
/// <paramref name="max_length"/> entries, filling each new slot with a
/// fresh copy of <paramref name="pad_tok"/>.
/// </summary>
/// <param name="sequences">sentences; resized in place</param>
/// <param name="pad_tok">the pad word (array of ids) to copy into new slots</param>
/// <param name="max_length">target number of words per sentence</param>
/// <returns>the (mutated) sequences plus each sentence's original length</returns>
private static (int[][][], int[]) _pad_sequences(int[][][] sequences, int[] pad_tok, int max_length)
{
    var original_lengths = new int[sequences.Length];

    for (int i = 0; i < sequences.Length; i++)
    {
        original_lengths[i] = sequences[i].Length;
        Array.Resize(ref sequences[i], max_length);

        // Every slot beyond the original length gets its own copy of the
        // pad word (a shared reference would alias all pad slots).
        for (int slot = original_lengths[i]; slot < max_length; slot++)
            sequences[i][slot] = (int[])pad_tok.Clone();
    }

    return (sequences, original_lengths);
}
}
}

+ 91
- 2
test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@@ -7,6 +8,7 @@ using Tensorflow;
using Tensorflow.Estimator;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
using static TensorFlowNET.Examples.DataHelpers;

namespace TensorFlowNET.Examples.Text.NER
{
@@ -27,10 +29,17 @@ namespace TensorFlowNET.Examples.Text.NER

HyperParams hp;
Dictionary<string, int> vocab_tags = new Dictionary<string, int>();
int nwords, nchars, ntags;
CoNLLDataset dev, train;

Tensor word_ids_tensor;
Tensor sequence_lengths_tensor;
Tensor char_ids_tensor;
Tensor word_lengths_tensor;
Tensor labels_tensor;
Tensor dropout_tensor;
Tensor lr_tensor;

public bool Run()
{
PrepareData();
@@ -38,6 +47,14 @@ namespace TensorFlowNET.Examples.Text.NER

tf.train.import_meta_graph("graph/lstm_crf_ner.meta");

word_ids_tensor = graph.OperationByName("word_ids");
sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
char_ids_tensor = graph.OperationByName("char_ids");
word_lengths_tensor = graph.OperationByName("word_lengths");
labels_tensor = graph.OperationByName("labels");
dropout_tensor = graph.OperationByName("dropout");
lr_tensor = graph.OperationByName("lr");

var init = tf.global_variables_initializer();

with(tf.Session(), sess =>
@@ -47,6 +64,7 @@ namespace TensorFlowNET.Examples.Text.NER
foreach (var epoch in range(hp.epochs))
{
print($"Epoch {epoch + 1} out of {hp.epochs}");
run_epoch(train, dev, epoch);
}

});
@@ -54,6 +72,77 @@ namespace TensorFlowNET.Examples.Text.NER
return true;
}

/// <summary>
/// Runs one pass over the training data in mini-batches.
/// NOTE(review): training is still being wired up — the feed dictionary
/// built by get_feed_dict is discarded and no session.run is issued yet
/// (get_feed_dict currently ends in NotImplementedException). The dev set
/// and epoch index are accepted but not used so far.
/// </summary>
private void run_epoch(CoNLLDataset train, CoNLLDataset dev, int epoch)
{
int i = 0; // batch counter — not used yet
// iterate over dataset
var batches = minibatches(train, hp.batch_size);
foreach (var(words, labels) in batches)
{
get_feed_dict(words, labels, hp.lr, hp.dropout);
}
}

/// <summary>
/// Groups the dataset's sentences into batches of at most
/// <paramref name="minibatch_size"/> items. Each batch pairs the sentence
/// features (per-word char ids, per-word ids) with the tag-id sequences.
/// </summary>
/// <param name="data">source dataset, enumerated lazily</param>
/// <param name="minibatch_size">maximum sentences per batch</param>
private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
{
    var features = new List<(int[][], int[])>();
    var tags = new List<int[]>();

    foreach (var (sentence, sentenceTags) in data.GetItems())
    {
        // Emit the current batch once it is full, then start a new one.
        if (tags.Count == minibatch_size)
        {
            yield return (features.ToArray(), tags.ToArray());
            features.Clear();
            tags.Clear();
        }

        // Split each word's (char_ids, word_id) pair into parallel arrays.
        var charIds = new int[sentence.Length][];
        var wordIds = new int[sentence.Length];
        for (int w = 0; w < sentence.Length; w++)
        {
            charIds[w] = sentence[w].Item1;
            wordIds[w] = sentence[w].Item2;
        }

        features.Add((charIds, wordIds));
        tags.Add(sentenceTags);
    }

    // Flush the trailing partial batch, if any.
    if (tags.Count > 0)
        yield return (features.ToArray(), tags.ToArray());
}

/// <summary>
/// Given some data, pad it and build a feed dictionary for the imported graph.
/// NOTE(review): still a work in progress — the method always throws
/// NotImplementedException after assembling the word/char feeds; the labels,
/// lr and dropout feeds are not added yet.
/// </summary>
/// <param name="words">
/// list of sentences. A sentence is a list of ids of a list of
/// words. A word is a list of ids
/// </param>
/// <param name="labels">list of ids (currently unused)</param>
/// <param name="lr">learning rate (currently unused)</param>
/// <param name="dropout">keep prob (currently unused)</param>
/// <returns>never returns yet — throws NotImplementedException</returns>
private FeedItem[] get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
{
int[] sequence_lengths;
int[][] word_lengths;
int[][] word_ids;
int[][][] char_ids;

// Hard-coded placeholder for a future "use_chars" config flag.
if (true) // use_chars
{
// Split the (char_ids, word_ids) pairs, then pad words per sentence and
// chars per word; pad_sequences mutates and returns the padded arrays.
(char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray());
(word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0);
(char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0);
}

// build feed dictionary
var feeds = new List<FeedItem>();
feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids)));
feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths)));
if(true) // use_chars
{
feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids)));
feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
}

throw new NotImplementedException("get_feed_dict");
}

public void PrepareData()
{
hp = new HyperParams("LstmCrfNer")


+ 51
- 19
test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs View File

@@ -8,13 +8,14 @@ using Tensorflow.Estimator;

namespace TensorFlowNET.Examples.Utility
{
public class CoNLLDataset : IEnumerable
public class CoNLLDataset
{
static Dictionary<string, int> vocab_chars;
static Dictionary<string, int> vocab_words;
static Dictionary<string, int> vocab_tags;

List<Tuple<int[], int>> _elements;
HyperParams _hp;
string _path;

public CoNLLDataset(string path, HyperParams hp)
{
@@ -24,22 +25,10 @@ namespace TensorFlowNET.Examples.Utility
if (vocab_words == null)
vocab_words = load_vocab(hp.filepath_words);

var lines = File.ReadAllLines(path);
if (vocab_tags == null)
vocab_tags = load_vocab(hp.filepath_tags);

foreach (var l in lines)
{
string line = l.Trim();
if (string.IsNullOrEmpty(line) || line.StartsWith("-DOCSTART-"))
{

}
else
{
var ls = line.Split(' ');
// process word
var word = processing_word(ls[0]);
}
}
_path = path;
}

private (int[], int) processing_word(string word)
@@ -58,6 +47,20 @@ namespace TensorFlowNET.Examples.Utility
return (char_ids, id);
}

/// <summary>
/// Maps a tag string to its integer id using the shared tag vocabulary.
/// Returns -1 when the tag is not present in the vocabulary.
/// </summary>
private int processing_tag(string word)
{
// 1. preprocess word
// Placeholders for config-driven preprocessing (lowercasing / digit
// normalization); both branches are intentionally disabled for tags.
// TODO: replace the hard-coded false with real config flags.
if (false) // lowercase
word = word.ToLower();
if (false) // isdigit
word = "$NUM$";

// 2. get id of word
int id = vocab_tags.GetValueOrDefault(word, -1);

return id;
}

private Dictionary<string, int> load_vocab(string filename)
{
var dict = new Dictionary<string, int>();
@@ -68,9 +71,38 @@ namespace TensorFlowNET.Examples.Utility
return dict;
}

public IEnumerator GetEnumerator()
public IEnumerable<((int[], int)[], int[])> GetItems()
{
return _elements.GetEnumerator();
var lines = File.ReadAllLines(_path);

int niter = 0;
var words = new List<(int[], int)>();
var tags = new List<int>();

foreach (var l in lines)
{
string line = l.Trim();
if (string.IsNullOrEmpty(line) || line.StartsWith("-DOCSTART-"))
{
if (words.Count > 0)
{
niter++;
yield return (words.ToArray(), tags.ToArray());
words.Clear();
tags.Clear();
}
}
else
{
var ls = line.Split(' ');
// process word
var word = processing_word(ls[0]);
var tag = processing_tag(ls[1]);

words.Add(word);
tags.Add(tag);
}
}
}
}
}

Loading…
Cancel
Save