diff --git a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs index 1906e305..66e4cc0a 100644 --- a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs +++ b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs @@ -11,10 +11,8 @@ namespace TensorFlowNET.Examples { public class DataHelpers { - private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; - private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; - public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null) + public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true) { if (model != "vd_cnn") throw new NotImplementedException(model); @@ -26,8 +24,10 @@ namespace TensorFlowNET.Examples char_dict[""] = 1; foreach (char c in alphabet) char_dict[c.ToString()] = char_dict.Count; - var contents = new Random(17).Shuffle( File.ReadAllLines(TRAIN_PATH)); - //File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400)); + var contents = File.ReadAllLines(path); + if (shuffle) + new Random(17).Shuffle(contents); + //File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400)); var size = limit == null ? contents.Length : limit.Value; var x = new int[size][]; diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index 528c453f..1c1d84bb 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -25,15 +25,20 @@ namespace TensorFlowNET.Examples.CnnTextClassification public string Name => "Text Classification"; public int? DataLimit = null; public bool ImportGraph { get; set; } = true; + public bool UseSubset = true; // <----- set this true to use a limited subset of dbpedia private string dataDir = "text_classification"; private string dataFileName = "dbpedia_csv.tar.gz"; public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn + private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; + private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv"; + private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; + private const int CHAR_MAX_LEN = 1014; private const int WORD_MAX_LEN = 1014; - private const int NUM_CLASS = 2; + private const int NUM_CLASS = 14; private const int BATCH_SIZE = 64; private const int NUM_EPOCHS = 10; protected float loss_value = 0; @@ -55,7 +60,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification { var stopwatch = Stopwatch.StartNew(); Console.WriteLine("Building dataset..."); - var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit = null); + var path = UseSubset ? SUBSET_PATH : TRAIN_PATH; + var (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset); Console.WriteLine("\tDONE "); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); @@ -249,9 +255,18 @@ namespace TensorFlowNET.Examples.CnnTextClassification public void PrepareData() { - string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; - Web.Download(url, dataDir, dataFileName); - Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); + if (UseSubset) + { + var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip"; + Web.Download(url, dataDir, "dbpedia_subset.zip"); + Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); + } + else + { + string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; + Web.Download(url, dataDir, dataFileName); + Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); + } if (ImportGraph) { @@ -264,7 +279,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification Console.WriteLine("Discarding cached file: " + meta_path); File.Delete(meta_path); } - url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; + var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; Web.Download(url, "graph", meta_file); } }