diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
index 15f1e55b..d9877d8c 100644
--- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
@@ -1,6 +1,7 @@
 using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Text;
@@ -52,7 +53,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
         {
             Console.WriteLine("Building dataset...");
-            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
+            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit=null);
             Console.WriteLine("\tDONE");
 
             var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
@@ -76,12 +77,13 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             Tensor optimizer = graph.get_operation_by_name("loss/optimizer");
             Tensor global_step = graph.get_operation_by_name("global_step");
             Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");
-
+            var stopwatch = Stopwatch.StartNew();
             int i = 0;
-            foreach (var (x_batch, y_batch) in train_batches)
+            foreach (var (x_batch, y_batch, total) in train_batches)
             {
                 i++;
-                Console.WriteLine("Training on batch " + i);
+                var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
+                Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
                 var train_feed_dict = new Hashtable
                 {
                     [model_x] = x_batch,
@@ -90,8 +92,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 };
                 // original python:
                 //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
-                var result = sess.run(new Tensor[] { optimizer, global_step, loss }, train_feed_dict);
-                // exception here, loss value seems like a float[]
+                var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
                 //loss_value = result[2];
                 var step = result[1];
                 if (step % 10 == 0)
@@ -102,7 +103,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 // # Test accuracy with validation data for each epoch.
                 var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
                 var (sum_accuracy, cnt) = (0, 0);
-                foreach (var (valid_x_batch, valid_y_batch) in valid_batches)
+                foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
                 {
                     // valid_feed_dict = {
                     //     model.x: valid_x_batch,
@@ -170,16 +171,19 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             return (train_x, valid_x, train_y, valid_y);
         }
 
-        private IEnumerable<(NDArray, NDArray)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
+        private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
         {
-            var num_batches_per_epoch = (len(inputs) - 1); // batch_size + 1
+            var num_batches_per_epoch = (len(inputs) - 1) / batch_size;
+            var total_batches = num_batches_per_epoch * num_epochs;
             foreach (var epoch in range(num_epochs))
             {
                 foreach (var batch_num in range(num_batches_per_epoch))
                 {
                     var start_index = batch_num * batch_size;
                     var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
-                    yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index,end_index)]);
+                    if (end_index <= start_index)
+                        break;
+                    yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index,end_index)], total_batches);
                 }
             }
         }
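For context, here is a minimal standalone sketch (not part of the patch) of the two ideas the diff introduces: `batch_iter` yielding the total batch count alongside each batch, and the training loop using that count plus elapsed time to estimate total training time. It uses plain `int[]` data instead of `NDArray`, and all names in it are illustrative.

```cs
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

class BatchIterSketch
{
    // Yields (batch, totalBatches) so the consumer can report progress,
    // mirroring the reworked batch_iter: batches per epoch is derived from
    // batch_size, and empty trailing slices are skipped.
    static IEnumerable<(int[] batch, int total)> BatchIter(int[] inputs, int batchSize, int numEpochs)
    {
        var numBatchesPerEpoch = (inputs.Length - 1) / batchSize;
        var totalBatches = numBatchesPerEpoch * numEpochs;
        for (var epoch = 0; epoch < numEpochs; epoch++)
        {
            for (var batchNum = 0; batchNum < numBatchesPerEpoch; batchNum++)
            {
                var start = batchNum * batchSize;
                var end = Math.Min((batchNum + 1) * batchSize, inputs.Length);
                if (end <= start)
                    break; // guard against an empty final slice
                yield return (inputs.Skip(start).Take(end - start).ToArray(), totalBatches);
            }
        }
    }

    static void Main()
    {
        var data = Enumerable.Range(0, 100).ToArray();
        var stopwatch = Stopwatch.StartNew();
        var i = 0;
        foreach (var (batch, total) in BatchIter(data, batchSize: 32, numEpochs: 2))
        {
            i++;
            // Estimate = average time per batch so far * total number of batches.
            var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
            Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
        }
    }
}
```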