diff --git a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
index f34b132b..1b83da37 100644
--- a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
@@ -39,6 +39,9 @@ namespace TensorFlowNET.Examples.Text.NER
Tensor labels_tensor;
Tensor dropout_tensor;
Tensor lr_tensor;
+ Operation train_op;
+ Tensor loss;
+ Tensor merged;
public bool Run()
{
@@ -47,6 +50,9 @@ namespace TensorFlowNET.Examples.Text.NER
tf.train.import_meta_graph("graph/lstm_crf_ner.meta");
+ float loss_value = 0f;
+
+ //add_summary();
word_ids_tensor = graph.OperationByName("word_ids");
sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
char_ids_tensor = graph.OperationByName("char_ids");
@@ -54,6 +60,9 @@ namespace TensorFlowNET.Examples.Text.NER
labels_tensor = graph.OperationByName("labels");
dropout_tensor = graph.OperationByName("dropout");
lr_tensor = graph.OperationByName("lr");
+ train_op = graph.OperationByName("train_step/Adam");
+ loss = graph.OperationByName("Mean");
+ //merged = graph.OperationByName("Merge/MergeSummary");
var init = tf.global_variables_initializer();
@@ -63,24 +72,28 @@ namespace TensorFlowNET.Examples.Text.NER
foreach (var epoch in range(hp.epochs))
{
- print($"Epoch {epoch + 1} out of {hp.epochs}");
- run_epoch(train, dev, epoch);
+ Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, ");
+ loss_value = run_epoch(sess, train, dev, epoch);
+ print($"train loss: {loss_value}");
}
-
});
- return true;
+ return loss_value < 0.1;
}
- private void run_epoch(CoNLLDataset train, CoNLLDataset dev, int epoch)
+ private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch) // one training pass over `train`; returns the loss fetched for the last minibatch
{
- int i = 0;
+ NDArray results = null; // values fetched by the most recent sess.run; stays null when no minibatch is produced
+ // NOTE: `dev` and `epoch` are accepted but not read anywhere in this method.
// iterate over dataset
var batches = minibatches(train, hp.batch_size);
foreach (var(words, labels) in batches)
{
- get_feed_dict(words, labels, hp.lr, hp.dropout);
+ var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout); // second tuple item (sequence lengths) is not needed here
+ results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd); // fetches requested in the order [train_op, loss]
}
+ if (results is null) throw new InvalidOperationException("run_epoch: no minibatches were produced, so there is no loss to return."); // fail clearly instead of a NullReferenceException below
+ return results[1];
}
private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
@@ -115,7 +128,7 @@ namespace TensorFlowNET.Examples.Text.NER
/// list of ids
/// learning rate
/// keep prob
- private FeedItem[] get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
+ private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
{
int[] sequence_lengths;
int[][] word_lengths;
@@ -140,14 +153,21 @@ namespace TensorFlowNET.Examples.Text.NER
feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
}
- throw new NotImplementedException("get_feed_dict");
+ (labels, _) = pad_sequences(labels, 0);
+ feeds.Add(new FeedItem(labels_tensor, np.array(labels)));
+
+ feeds.Add(new FeedItem(lr_tensor, lr));
+
+ feeds.Add(new FeedItem(dropout_tensor, dropout));
+
+ return (feeds.ToArray(), sequence_lengths);
}
public void PrepareData()
{
hp = new HyperParams("LstmCrfNer")
{
- epochs = 15,
+ epochs = 50,
dropout = 0.5f,
batch_size = 20,
lr_method = "adam",