Browse Source

TextClassification: support training against a DBpedia subset

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
f4d25c9c6a
2 changed files with 26 additions and 11 deletions
  1. +5
    -5
      test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
  2. +21
    -6
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

+ 5
- 5
test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs View File

@@ -11,10 +11,8 @@ namespace TensorFlowNET.Examples
{
public class DataHelpers
{
private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true)
{
if (model != "vd_cnn")
throw new NotImplementedException(model);
@@ -26,8 +24,10 @@ namespace TensorFlowNET.Examples
char_dict["<unk>"] = 1;
foreach (char c in alphabet)
char_dict[c.ToString()] = char_dict.Count;
var contents = new Random(17).Shuffle( File.ReadAllLines(TRAIN_PATH));
//File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400));
var contents = File.ReadAllLines(path);
if (shuffle)
new Random(17).Shuffle(contents);
//File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400));
var size = limit == null ? contents.Length : limit.Value;

var x = new int[size][];


+ 21
- 6
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -25,15 +25,20 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public string Name => "Text Classification";
public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
public bool UseSubset = true; // <----- set this to true to use a limited subset of DBpedia

private string dataDir = "text_classification";
private string dataFileName = "dbpedia_csv.tar.gz";

public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn

private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

private const int CHAR_MAX_LEN = 1014;
private const int WORD_MAX_LEN = 1014;
private const int NUM_CLASS = 2;
private const int NUM_CLASS = 14;
private const int BATCH_SIZE = 64;
private const int NUM_EPOCHS = 10;
protected float loss_value = 0;
@@ -55,7 +60,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
{
var stopwatch = Stopwatch.StartNew();
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit = null);
var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;
var (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset);
Console.WriteLine("\tDONE ");

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
@@ -249,9 +255,18 @@ namespace TensorFlowNET.Examples.CnnTextClassification

public void PrepareData()
{
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
if (UseSubset)
{
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
}
else
{
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
}

if (ImportGraph)
{
@@ -264,7 +279,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
}
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);
}
}


Loading…
Cancel
Save