Browse Source

sync

tags/v0.9
Oceania2018 6 years ago
parent
commit
bf53773f64
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

+ 5
- 5
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -55,7 +55,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
{
var stopwatch = Stopwatch.StartNew();
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit=null);
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit = null);
Console.WriteLine("\tDONE ");

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
@@ -169,8 +169,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
//int classes = y.Data<int>().Distinct().Count();
//int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop:train_size), new Slice()];
var valid_x = x[new Slice(start: train_size+1), new Slice()];
var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size + 1), new Slice()];
var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size + 1)];
Console.WriteLine("\tDONE");
@@ -179,7 +179,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification

private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
var num_batches_per_epoch = (len(inputs) - 1) / batch_size +1;
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
var total_batches = num_batches_per_epoch * num_epochs;
foreach (var epoch in range(num_epochs))
{
@@ -189,7 +189,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
if (end_index <= start_index)
break;
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index,end_index)], total_batches);
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
}
}
}


Loading…
Cancel
Save