|
|
@@ -55,7 +55,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification |
|
|
|
{ |
|
|
|
var stopwatch = Stopwatch.StartNew(); |
|
|
|
Console.WriteLine("Building dataset..."); |
|
|
|
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit=null); |
|
|
|
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit = null); |
|
|
|
Console.WriteLine("\tDONE "); |
|
|
|
|
|
|
|
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); |
|
|
@@ -169,8 +169,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification |
|
|
|
//int classes = y.Data<int>().Distinct().Count(); |
|
|
|
//int samples = len / classes; |
|
|
|
int train_size = (int)Math.Round(len * (1 - test_size)); |
|
|
|
var train_x = x[new Slice(stop:train_size), new Slice()]; |
|
|
|
var valid_x = x[new Slice(start: train_size+1), new Slice()]; |
|
|
|
var train_x = x[new Slice(stop: train_size), new Slice()]; |
|
|
|
var valid_x = x[new Slice(start: train_size + 1), new Slice()]; |
|
|
|
var train_y = y[new Slice(stop: train_size)]; |
|
|
|
var valid_y = y[new Slice(start: train_size + 1)]; |
|
|
|
Console.WriteLine("\tDONE"); |
|
|
@@ -179,7 +179,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification |
|
|
|
|
|
|
|
private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs) |
|
|
|
{ |
|
|
|
var num_batches_per_epoch = (len(inputs) - 1) / batch_size +1; |
|
|
|
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1; |
|
|
|
var total_batches = num_batches_per_epoch * num_epochs; |
|
|
|
foreach (var epoch in range(num_epochs)) |
|
|
|
{ |
|
|
@@ -189,7 +189,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification |
|
|
|
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs)); |
|
|
|
if (end_index <= start_index) |
|
|
|
break; |
|
|
|
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index,end_index)], total_batches); |
|
|
|
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|