From caeb0e3050b34fe8e13a41ef7e2257c3340807f3 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Sat, 17 Aug 2019 13:47:06 -0500
Subject: [PATCH] Overload session.run(), make syntax simpler.

Add tuple-returning overloads of Session.run() and an implicit
(object, object) -> FeedItem conversion, so callers can write e.g.
`(loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch))`
instead of indexing into an NDArray[] result; the examples are updated to the new syntax.
---
 .../Sessions/BaseSession.cs                   | 28 +++++++++++++++++
 src/TensorFlowNET.Core/Sessions/FeedItem.cs   |  3 ++
 src/TensorFlowNET.Core/ops.py.cs              |  2 +-
 .../BasicModels/LinearRegression.cs           | 10 +++----
 .../BasicModels/LogisticRegression.cs         |  9 +++---
 .../BasicModels/NearestNeighbor.cs            |  4 +--
 .../BasicModels/NeuralNetXor.cs               | 12 +++-----
 .../TensorFlowNET.Examples/BasicOperations.cs |  8 ++---
 test/TensorFlowNET.Examples/HelloWorld.cs     |  2 +-
 .../ImageProcessing/DigitRecognitionCNN.cs    | 10 ++-----
 .../ImageProcessing/DigitRecognitionNN.cs     | 15 +++-------
 .../ImageRecognitionInception.cs              |  4 +--
 .../ImageProcessing/ObjectDetection.cs        |  2 +-
 .../ImageProcessing/RetrainImageClassifier.cs | 30 ++++++++-----------
 .../TextProcessing/CnnTextClassification.cs   | 22 ++++----------
 .../TextProcessing/NER/LstmCrfNer.cs          |  7 ++---
 .../TextProcessing/Word2Vec.cs                |  6 ++--
 17 files changed, 87 insertions(+), 87 deletions(-)

diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 0f0001eb..efe2afd4 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -54,6 +54,34 @@ namespace Tensorflow
             status.Check(true);
         }
 
+        public virtual void run(Operation op, params FeedItem[] feed_dict)
+        {
+            _run(op, feed_dict);
+        }
+
+        public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
+        {
+            return _run(fetche, feed_dict)[0];
+        }
+
+        public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
+        {
+            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
+            return (results[0], results[1], results[2], results[3]);
+        }
+
+        public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
+        {
+            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
+            return (results[0], results[1], results[2]);
+        }
+
+        public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
+        {
+            var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
+            return (results[0], results[1]);
+        }
+
         public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
         {
             return _run(fetches, feed_dict);
diff --git a/src/TensorFlowNET.Core/Sessions/FeedItem.cs b/src/TensorFlowNET.Core/Sessions/FeedItem.cs
index f60f7fb0..f87457e7 100644
--- a/src/TensorFlowNET.Core/Sessions/FeedItem.cs
+++ b/src/TensorFlowNET.Core/Sessions/FeedItem.cs
@@ -13,5 +13,8 @@
             Key = key;
             Value = val;
         }
+
+        public static implicit operator FeedItem((object, object) feed)
+            => new FeedItem(feed.Item1, feed.Item2);
     }
 }
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index ef313f49..979e132e 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -377,7 +377,7 @@ namespace Tensorflow
                     "`eval(session=sess)`.");
             }
 
-            return session.run(tensor, feed_dict)[0];
+            return session.run(tensor, feed_dict);
         }
 
         ///
diff --git a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
index 6fb6c08d..0098404d 100644
--- a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
+++ b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
@@ -91,16 +91,16 @@ namespace TensorFlowNET.Examples
                     {
                         var c = sess.run(cost,
                             new FeedItem(X, train_X),
-                            new FeedItem(Y, train_Y))[0];
-                        Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)[0]} b={sess.run(b)[0]}");
+                            new FeedItem(Y, train_Y));
+                        Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
                     }
                 }
 
                 Console.WriteLine("Optimization Finished!");
                 var training_cost = sess.run(cost,
                     new FeedItem(X, train_X),
-                    new FeedItem(Y, train_Y))[0];
-                Console.WriteLine($"Training cost={training_cost} W={sess.run(W)[0]} b={sess.run(b)[0]}");
+                    new FeedItem(Y, train_Y));
+                Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");
 
                 // Testing example
                 var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
@@ -108,7 +108,7 @@ namespace TensorFlowNET.Examples
                 Console.WriteLine("Testing... (Mean square loss Comparison)");
                 var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]),
                     new FeedItem(X, test_X),
-                    new FeedItem(Y, test_Y))[0];
+                    new FeedItem(Y, test_Y));
                 Console.WriteLine($"Testing cost={testing_cost}");
                 var diff = Math.Abs((float)training_cost - (float)testing_cost);
                 Console.WriteLine($"Absolute mean square loss difference: {diff}");
diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
index 3f6756f6..1b34b961 100644
--- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
+++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
@@ -90,11 +90,10 @@ namespace TensorFlowNET.Examples
                 {
                     var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size);
                     // Run optimization op (backprop) and cost op (to get loss value)
-                    var result = sess.run(new object[] { optimizer, cost },
-                        new FeedItem(x, batch_xs),
-                        new FeedItem(y, batch_ys));
+                    (_, float c) = sess.run((optimizer, cost),
+                        (x, batch_xs),
+                        (y, batch_ys));
 
-                    float c = result[1];
                     // Compute average loss
                     avg_cost += c / total_batch;
                 }
@@ -115,7 +114,7 @@ namespace TensorFlowNET.Examples
             var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
             // Calculate accuracy
             var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
-            float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
+            float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels));
             print($"Accuracy: {acc.ToString("F4")}");
 
             return acc > 0.9;
diff --git a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
index 5e471e6e..eb96d275 100644
--- a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
+++ b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
@@ -64,7 +64,7 @@ namespace TensorFlowNET.Examples
                 foreach(int i in range(Xte.shape[0]))
                 {
                     // Get nearest neighbor
-                    long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]))[0];
+                    long nn_index = sess.run(pred, (xtr, Xtr), (xte, Xte[i]));
 
                     // Get nearest neighbor class label and compare it to its true label
                     int index = (int)nn_index;
@@ -72,7 +72,7 @@ namespace TensorFlowNET.Examples
                     print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}");
 
                     // Calculate accuracy
-                    if ((int)np.argmax(Ytr[index]) == (int)np.argmax(Yte[i]))
+                    if (np.argmax(Ytr[index]) == np.argmax(Yte[i]))
                         accuracy += 1f/ Xte.shape[0];
                 }
diff --git a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs b/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs
index 12687e3f..8b158012 100644
--- a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs
+++ b/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs
@@ -103,10 +103,8 @@ namespace TensorFlowNET.Examples
                     //     [train_op, gs, loss],
                     //     feed_dict={features: xy, labels: y_}
                     // )
-                    var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
-                    loss_value = result[2];
-                    step = result[1];
-                    if (step % 1000 == 0)
+                    (_, step, loss_value) = sess.run((train_op, global_step, loss), (features, data), (labels, y_));
+                    if (step == 1 || step % 1000 == 0)
                         Console.WriteLine($"Step {step} loss: {loss_value}");
                 }
                 Console.WriteLine($"Final loss: {loss_value}");
@@ -136,10 +134,8 @@ namespace TensorFlowNET.Examples
                 var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
                 while (step < num_steps)
                 {
-                    var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
-                    loss_value = result[2];
-                    step = result[1];
-                    if (step % 1000 == 0)
+                    (_, step, loss_value) = sess.run((train_op, gs, loss), (features, data), (labels, y_));
+                    if (step == 1 || step % 1000 == 0)
                         Console.WriteLine($"Step {step} loss: {loss_value}");
                 }
                 Console.WriteLine($"Final loss: {loss_value}");
diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs
index d95d5424..c7314abe 100644
--- a/test/TensorFlowNET.Examples/BasicOperations.cs
+++ b/test/TensorFlowNET.Examples/BasicOperations.cs
@@ -53,8 +53,8 @@ namespace TensorFlowNET.Examples
                     new FeedItem(b, (short)3)
                 };
                 // Run every operation with variable input
-                Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)[0]}");
-                Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)[0]}");
+                Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}");
+                Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}");
             }
 
             // ----------------
@@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples
             // The output of the op is returned in 'result' as a numpy `ndarray` object.
             using (sess = tf.Session())
            {
-                var result = sess.run(product)[0];
+                var result = sess.run(product);
                 Console.WriteLine(result.ToString()); // ==> [[ 12.]]
             };
 
@@ -136,7 +136,7 @@ namespace TensorFlowNET.Examples
             var checkTensor = np.array(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0);
             using (var sess = tf.Session())
             {
-                var result = sess.run(batchMul)[0];
+                var result = sess.run(batchMul);
                 Console.WriteLine(result.ToString());
                 //
                 // ==> array([[[0, 6],
diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs
index da4a924a..52e47e3d 100644
--- a/test/TensorFlowNET.Examples/HelloWorld.cs
+++ b/test/TensorFlowNET.Examples/HelloWorld.cs
@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples
             using (var sess = tf.Session())
             {
                 // Run the op
-                var result = sess.run(hello)[0];
+                var result = sess.run(hello);
                 Console.WriteLine(result.ToString());
                 return result.ToString().Equals(str);
             }
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
index 25171974..9c22b149 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
@@ -160,7 +160,7 @@ namespace TensorFlowNET.Examples
                     var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
 
                     // Run optimization op (backprop)
-                    sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+                    sess.run(optimizer, (x, x_batch), (y, y_batch));
 
                     if (iteration % display_freq == 0)
                     {
@@ -174,9 +174,7 @@ namespace TensorFlowNET.Examples
                 }
 
                 // Run validation after every epoch
-                var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
-                loss_val = results1[0];
-                accuracy_val = results1[1];
+                (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
                 print("---------------------------------------------------------");
                 print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
                 print("---------------------------------------------------------");
@@ -185,9 +183,7 @@ namespace TensorFlowNET.Examples
 
         public void Test(Session sess)
         {
-            var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
-            loss_test = result[0];
-            accuracy_test = result[1];
+            (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test));
             print("---------------------------------------------------------");
             print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
             print("---------------------------------------------------------");
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
index ebc1240a..9125c286 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
@@ -148,23 +148,18 @@ namespace TensorFlowNET.Examples
                     var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
 
                     // Run optimization op (backprop)
-                    sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+                    sess.run(optimizer, (x, x_batch), (y, y_batch));
 
                     if (iteration % display_freq == 0)
                     {
                         // Calculate and display the batch loss and accuracy
-                        var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
-                        loss_val = result[0];
-                        accuracy_val = result[1];
+                        (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch));
                         print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
                     }
                 }
 
                 // Run validation after every epoch
-                var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels));
-
-                loss_val = results1[0];
-                accuracy_val = results1[1];
+                (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, mnist.Validation.Data), (y, mnist.Validation.Labels));
                 print("---------------------------------------------------------");
                 print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
                 print("---------------------------------------------------------");
@@ -173,9 +168,7 @@ namespace TensorFlowNET.Examples
 
         public void Test(Session sess)
        {
-            var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
-            loss_test = result[0];
-            accuracy_test = result[1];
+            (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, mnist.Test.Data), (y, mnist.Test.Labels));
             print("---------------------------------------------------------");
             print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
             print("---------------------------------------------------------");
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs
index 366961e2..85e0357a 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs
@@ -51,7 +51,7 @@ namespace TensorFlowNET.Examples
                 {
                     sw.Restart();
 
-                    var results = sess.run(output_operation.outputs[0], new FeedItem(input_operation.outputs[0], nd))[0];
+                    var results = sess.run(output_operation.outputs[0], (input_operation.outputs[0], nd));
                     results = np.squeeze(results);
                     int idx = np.argmax(results);
 
@@ -81,7 +81,7 @@ namespace TensorFlowNET.Examples
             var normalized = tf.divide(sub, new float[] { input_std });
 
             using (var sess = tf.Session(graph))
-                return sess.run(normalized)[0];
+                return sess.run(normalized);
         }
 
         public void PrepareData()
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs
index a21a0e6f..f5c967a7 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs
@@ -108,7 +108,7 @@ namespace TensorFlowNET.Examples
             var dims_expander = tf.expand_dims(casted, 0);
 
             using (var sess = tf.Session(graph))
-                return sess.run(dims_expander)[0];
+                return sess.run(dims_expander);
         }
 
         private void buildOutputImage(NDArray[] resultArr)
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
index 7246c42e..92442d17 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
@@ -124,13 +124,13 @@ namespace TensorFlowNET.Examples
             var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step, prediction) =
                 build_eval_session(class_count);
 
-            var results = eval_session.run(new Tensor[] { evaluation_step, prediction },
-                new FeedItem(bottleneck_input, test_bottlenecks),
-                new FeedItem(ground_truth_input, test_ground_truth));
+            (float accuracy, NDArray prediction1) = eval_session.run((evaluation_step, prediction),
+                (bottleneck_input, test_bottlenecks),
+                (ground_truth_input, test_ground_truth));
 
-            print($"final test accuracy: {((float)results[0] * 100).ToString("G4")}% (N={len(test_bottlenecks)})");
+            print($"final test accuracy: {(accuracy * 100).ToString("G4")}% (N={len(test_bottlenecks)})");
 
-            return (results[0], results[1]);
+            return (accuracy, prediction1);
         }
 
         private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
@@ -661,11 +661,9 @@ namespace TensorFlowNET.Examples
                     bool is_last_step = (i + 1 == how_many_training_steps);
                     if ((i % eval_step_interval) == 0 || is_last_step)
                     {
-                        results = sess.run(
-                            new Tensor[] { evaluation_step, cross_entropy },
-                            new FeedItem(bottleneck_input, train_bottlenecks),
-                            new FeedItem(ground_truth_input, train_ground_truth));
-                        (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
+                        (float train_accuracy, float cross_entropy_value) = sess.run((evaluation_step, cross_entropy),
+                            (bottleneck_input, train_bottlenecks),
+                            (ground_truth_input, train_ground_truth));
                         print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");
 
                         var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
@@ -676,12 +674,10 @@ namespace TensorFlowNET.Examples
 
                         // Run a validation step and capture training summaries for TensorBoard
                         // with the `merged` op.
-                        results = sess.run(new Tensor[] { merged, evaluation_step },
-                            new FeedItem(bottleneck_input, validation_bottlenecks),
-                            new FeedItem(ground_truth_input, validation_ground_truth));
+                        (_, float validation_accuracy) = sess.run((merged, evaluation_step),
+                            (bottleneck_input, validation_bottlenecks),
+                            (ground_truth_input, validation_ground_truth));
 
-                        //(string validation_summary, float validation_accuracy) = (results[0], results[1]);
-                        float validation_accuracy = results[1];
                         // validation_writer.add_summary(validation_summary, i);
                         print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
                         sw.Restart();
@@ -741,10 +737,10 @@ namespace TensorFlowNET.Examples
 
             using (var sess = tf.Session(graph))
             {
-                var result = sess.run(output, new FeedItem(input, fileBytes));
+                var result = sess.run(output, (input, fileBytes));
                 var prob = np.squeeze(result);
                 var idx = np.argmax(prob);
-                print($"Prediction result: [{labels[idx]} {prob[idx][0]}] for {img_path}.");
+                print($"Prediction result: [{labels[idx]} {prob[idx]}] for {img_path}.");
             }
         }
diff --git a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
index 8b01df7a..d781bccd 100644
--- a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
@@ -213,22 +213,13 @@ namespace TensorFlowNET.Examples
             Tensor global_step = graph.OperationByName("Variable");
             Tensor accuracy = graph.OperationByName("accuracy/accuracy");
             stopwatch = Stopwatch.StartNew();
-            int i = 0;
+            int step = 0;
             foreach (var (x_batch, y_batch, total) in train_batches)
             {
-                i++;
-                var train_feed_dict = new FeedDict
-                {
-                    [model_x] = x_batch,
-                    [model_y] = y_batch,
-                    [is_training] = true,
-                };
-
-                var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
-                loss_value = result[2];
-                var step = (int)result[1];
-                if (step % 10 == 0)
-                    Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value.ToString("0.0000")}.");
+                (_, step, loss_value) = sess.run((optimizer, global_step, loss),
+                    (model_x, x_batch), (model_y, y_batch), (is_training, true));
+                if (step == 1 || step % 10 == 0)
+                    Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}.");
 
                 if (step % 100 == 0)
                 {
@@ -243,8 +234,7 @@ namespace TensorFlowNET.Examples
                             [model_y] = valid_y_batch,
                             [is_training] = false
                         };
-                        var result1 = sess.run(accuracy, valid_feed_dict);
-                        float accuracy_value = result1[0];
+                        float accuracy_value = sess.run(accuracy, (model_x, valid_x_batch), (model_y, valid_y_batch), (is_training, false));
                         sum_accuracy += accuracy_value;
                         cnt += 1;
                     }
diff --git a/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs
index 146bc080..db5f230f 100644
--- a/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs
+++ b/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs
@@ -80,17 +80,16 @@ namespace TensorFlowNET.Examples.Text.NER
 
         private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch)
         {
-            NDArray[] results = null;
-
+            float accuracy = 0;
             // iterate over dataset
             var batches = minibatches(train, hp.batch_size);
             foreach (var(words, labels) in batches)
             {
                 var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout);
-                results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd);
+                (_, accuracy) = sess.run((train_op, loss), feed_dict: fd);
             }
 
-            return results[1];
+            return accuracy;
         }
 
         private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
diff --git a/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs b/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
index 51dc270d..33f50aad 100644
--- a/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
+++ b/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
@@ -81,8 +81,8 @@ namespace TensorFlowNET.Examples
                 // Get a new batch of data
                 var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);
 
-                var result = sess.run(new ITensorOrOperation[] { train_op, loss_op }, new FeedItem(X, batch_x), new FeedItem(Y, batch_y));
-                average_loss += result[1];
+                (_, float loss) = sess.run((train_op, loss_op), (X, batch_x), (Y, batch_y));
+                average_loss += loss;
 
                 if (step % display_step == 0 || step == 1)
                 {
@@ -97,7 +97,7 @@ namespace TensorFlowNET.Examples
                 if (step % eval_step == 0 || step == 1)
                 {
                     print("Evaluation...");
-                    var sim = sess.run(cosine_sim_op, new FeedItem(X, x_test))[0];
+                    var sim = sess.run(cosine_sim_op, (X, x_test));
                     foreach(var i in range(len(eval_words)))
                     {
                         var nearest = (0f - sim[i]).argsort()