From bf53773f643cc52c751134e99efd25c4002237f2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 11 May 2019 08:38:52 -0500 Subject: [PATCH] sync --- .../TextProcess/TextClassificationTrain.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index 9d24e3c6..15ee6c1c 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -55,7 +55,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification { var stopwatch = Stopwatch.StartNew(); Console.WriteLine("Building dataset..."); - var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit=null); + var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit = null); Console.WriteLine("\tDONE "); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); @@ -169,8 +169,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification //int classes = y.Data().Distinct().Count(); //int samples = len / classes; int train_size = (int)Math.Round(len * (1 - test_size)); - var train_x = x[new Slice(stop:train_size), new Slice()]; - var valid_x = x[new Slice(start: train_size+1), new Slice()]; + var train_x = x[new Slice(stop: train_size), new Slice()]; + var valid_x = x[new Slice(start: train_size + 1), new Slice()]; var train_y = y[new Slice(stop: train_size)]; var valid_y = y[new Slice(start: train_size + 1)]; Console.WriteLine("\tDONE"); @@ -179,7 +179,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs) { - var num_batches_per_epoch = (len(inputs) - 1) / batch_size +1; + var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1; var total_batches = num_batches_per_epoch * num_epochs; foreach (var epoch in range(num_epochs)) { @@ -189,7 +189,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs)); if (end_index <= start_index) break; - yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index,end_index)], total_batches); + yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches); } } }