Browse Source

Finished NER training, but the run_evaluate function is still missing.

tags/v0.9
Oceania2018 6 years ago
parent
commit
3208ee56f8
1 changed file with 30 additions and 10 deletions
  1. +30
    -10
      test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs

+ 30
- 10
test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs View File

@@ -39,6 +39,9 @@ namespace TensorFlowNET.Examples.Text.NER
Tensor labels_tensor; Tensor labels_tensor;
Tensor dropout_tensor; Tensor dropout_tensor;
Tensor lr_tensor; Tensor lr_tensor;
Operation train_op;
Tensor loss;
Tensor merged;


public bool Run() public bool Run()
{ {
@@ -47,6 +50,9 @@ namespace TensorFlowNET.Examples.Text.NER


tf.train.import_meta_graph("graph/lstm_crf_ner.meta"); tf.train.import_meta_graph("graph/lstm_crf_ner.meta");


float loss_value = 0f;

//add_summary();
word_ids_tensor = graph.OperationByName("word_ids"); word_ids_tensor = graph.OperationByName("word_ids");
sequence_lengths_tensor = graph.OperationByName("sequence_lengths"); sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
char_ids_tensor = graph.OperationByName("char_ids"); char_ids_tensor = graph.OperationByName("char_ids");
@@ -54,6 +60,9 @@ namespace TensorFlowNET.Examples.Text.NER
labels_tensor = graph.OperationByName("labels"); labels_tensor = graph.OperationByName("labels");
dropout_tensor = graph.OperationByName("dropout"); dropout_tensor = graph.OperationByName("dropout");
lr_tensor = graph.OperationByName("lr"); lr_tensor = graph.OperationByName("lr");
train_op = graph.OperationByName("train_step/Adam");
loss = graph.OperationByName("Mean");
//merged = graph.OperationByName("Merge/MergeSummary");


var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();


@@ -63,24 +72,28 @@ namespace TensorFlowNET.Examples.Text.NER


foreach (var epoch in range(hp.epochs)) foreach (var epoch in range(hp.epochs))
{ {
print($"Epoch {epoch + 1} out of {hp.epochs}");
run_epoch(train, dev, epoch);
Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, ");
loss_value = run_epoch(sess, train, dev, epoch);
print($"train loss: {loss_value}");
} }

}); });


return true;
return loss_value < 0.1;
} }


private void run_epoch(CoNLLDataset train, CoNLLDataset dev, int epoch)
private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch)
{ {
int i = 0;
NDArray results = null;

// iterate over dataset // iterate over dataset
var batches = minibatches(train, hp.batch_size); var batches = minibatches(train, hp.batch_size);
foreach (var(words, labels) in batches) foreach (var(words, labels) in batches)
{ {
get_feed_dict(words, labels, hp.lr, hp.dropout);
var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout);
results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd);
} }

return results[1];
} }


private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size) private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
@@ -115,7 +128,7 @@ namespace TensorFlowNET.Examples.Text.NER
/// <param name="labels">list of ids</param> /// <param name="labels">list of ids</param>
/// <param name="lr">learning rate</param> /// <param name="lr">learning rate</param>
/// <param name="dropout">keep prob</param> /// <param name="dropout">keep prob</param>
private FeedItem[] get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
{ {
int[] sequence_lengths; int[] sequence_lengths;
int[][] word_lengths; int[][] word_lengths;
@@ -140,14 +153,21 @@ namespace TensorFlowNET.Examples.Text.NER
feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths))); feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
} }


throw new NotImplementedException("get_feed_dict");
(labels, _) = pad_sequences(labels, 0);
feeds.Add(new FeedItem(labels_tensor, np.array(labels)));

feeds.Add(new FeedItem(lr_tensor, lr));

feeds.Add(new FeedItem(dropout_tensor, dropout));

return (feeds.ToArray(), sequence_lengths);
} }


public void PrepareData() public void PrepareData()
{ {
hp = new HyperParams("LstmCrfNer") hp = new HyperParams("LstmCrfNer")
{ {
epochs = 15,
epochs = 50,
dropout = 0.5f, dropout = 0.5f,
batch_size = 20, batch_size = 20,
lr_method = "adam", lr_method = "adam",


Loading…
Cancel
Save