diff --git a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs index f34b132b..1b83da37 100644 --- a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs +++ b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs @@ -39,6 +39,9 @@ namespace TensorFlowNET.Examples.Text.NER Tensor labels_tensor; Tensor dropout_tensor; Tensor lr_tensor; + Operation train_op; + Tensor loss; + Tensor merged; public bool Run() { @@ -47,6 +50,9 @@ namespace TensorFlowNET.Examples.Text.NER tf.train.import_meta_graph("graph/lstm_crf_ner.meta"); + float loss_value = 0f; + + //add_summary(); word_ids_tensor = graph.OperationByName("word_ids"); sequence_lengths_tensor = graph.OperationByName("sequence_lengths"); char_ids_tensor = graph.OperationByName("char_ids"); @@ -54,6 +60,9 @@ namespace TensorFlowNET.Examples.Text.NER labels_tensor = graph.OperationByName("labels"); dropout_tensor = graph.OperationByName("dropout"); lr_tensor = graph.OperationByName("lr"); + train_op = graph.OperationByName("train_step/Adam"); + loss = graph.OperationByName("Mean"); + //merged = graph.OperationByName("Merge/MergeSummary"); var init = tf.global_variables_initializer(); @@ -63,24 +72,28 @@ namespace TensorFlowNET.Examples.Text.NER foreach (var epoch in range(hp.epochs)) { - print($"Epoch {epoch + 1} out of {hp.epochs}"); - run_epoch(train, dev, epoch); + Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, "); + loss_value = run_epoch(sess, train, dev, epoch); + print($"train loss: {loss_value}"); } - }); - return true; + return loss_value < 0.1; } - private void run_epoch(CoNLLDataset train, CoNLLDataset dev, int epoch) + private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch) { - int i = 0; + NDArray results = null; + // iterate over dataset var batches = minibatches(train, hp.batch_size); foreach (var(words, labels) in batches) { - get_feed_dict(words, labels, hp.lr, hp.dropout); + var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout); + results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd); } + + return results[1]; } private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size) @@ -115,7 +128,7 @@ namespace TensorFlowNET.Examples.Text.NER /// list of ids /// learning rate /// keep prob - private FeedItem[] get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f) + private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f) { int[] sequence_lengths; int[][] word_lengths; @@ -140,14 +153,21 @@ namespace TensorFlowNET.Examples.Text.NER feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths))); } - throw new NotImplementedException("get_feed_dict"); + (labels, _) = pad_sequences(labels, 0); + feeds.Add(new FeedItem(labels_tensor, np.array(labels))); + + feeds.Add(new FeedItem(lr_tensor, lr)); + + feeds.Add(new FeedItem(dropout_tensor, dropout)); + + return (feeds.ToArray(), sequence_lengths); } public void PrepareData() { hp = new HyperParams("LstmCrfNer") { - epochs = 15, + epochs = 50, dropout = 0.5f, batch_size = 20, lr_method = "adam",