|
|
@@ -26,14 +26,13 @@ namespace TensorFlowNET.Examples
         public string Name => "CNN Text Classification";
         public int? DataLimit = null;
         public bool ImportGraph { get; set; } = true;
 
-        public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia
-        private string dataDir = "text_classification";
+        private string dataDir = "word_cnn";
         private string dataFileName = "dbpedia_csv.tar.gz";
 
         private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
         private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
 
         private const int NUM_CLASS = 14;
         private const int BATCH_SIZE = 64;
         private const int NUM_EPOCHS = 10;
@@ -41,6 +40,7 @@ namespace TensorFlowNET.Examples
         private const int CHAR_MAX_LEN = 1014;
 
         protected float loss_value = 0;
+        int vocabulary_size = 50000;
 
         public bool Run()
         {
@@ -63,10 +63,9 @@ namespace TensorFlowNET.Examples
             int[][] x = null;
             int[] y = null;
             int alphabet_size = 0;
-            int vocabulary_size = 0;
 
             var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
-            vocabulary_size = len(word_dict);
+            // vocabulary_size = len(word_dict);
             (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
 
             Console.WriteLine("\tDONE ");
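
This hunk stops deriving vocabulary_size from the dictionary and instead relies on the fixed field int vocabulary_size = 50000 added above, which is only safe if DataHelpers.build_word_dict itself caps the dictionary at 50,000 entries. The diff does not show DataHelpers, so the following is a hypothetical sketch of the contract the patched code appears to assume: a frequency-capped word dictionary, and each row of x encoded as exactly WORD_MAX_LEN word ids with padding and truncation. All names here (BuildWordDict, Encode, the <pad>/<unk> tokens) are illustrative, not taken from the repository.

    using System.Collections.Generic;
    using System.Linq;

    static class WordDataSketch
    {
        // Frequency-ordered dictionary capped at maxVocab entries, ids 0..maxVocab-1;
        // under this cap a fixed vocabulary_size of 50000 stays valid.
        public static Dictionary<string, int> BuildWordDict(IEnumerable<string> docs, int maxVocab = 50000)
        {
            var counts = new Dictionary<string, int>();
            foreach (var w in docs.SelectMany(d => d.ToLowerInvariant().Split(' ')))
                counts[w] = counts.TryGetValue(w, out var c) ? c + 1 : 1;

            var dict = new Dictionary<string, int> { ["<pad>"] = 0, ["<unk>"] = 1 };
            foreach (var w in counts.OrderByDescending(kv => kv.Value).Select(kv => kv.Key))
            {
                if (dict.Count >= maxVocab) break;
                dict[w] = dict.Count;
            }
            return dict;
        }

        // One row of the x matrix: unknown words map to <unk>, short documents are
        // padded with <pad>, long ones truncated, so every row has exactly maxLen ids.
        public static int[] Encode(string doc, Dictionary<string, int> dict, int maxLen)
        {
            var ids = doc.ToLowerInvariant().Split(' ')
                         .Select(w => dict.TryGetValue(w, out var id) ? id : dict["<unk>"]);
            return ids.Concat(Enumerable.Repeat(dict["<pad>"], maxLen)).Take(maxLen).ToArray();
        }
    }
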
@@ -142,7 +141,7 @@ namespace TensorFlowNET.Examples
                         if (valid_accuracy > max_accuracy)
                         {
                             max_accuracy = valid_accuracy;
-                            saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step.ToString());
+                            saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step);
                             print("Model is saved.\n");
                         }
                     }
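
Passing step as a number rather than step.ToString() matters here, presumably because Saver.save uses the numeric global_step to suffix the checkpoint files (word_cnn.ckpt-<step>) and to track the latest checkpoint. As a minimal usage sketch, assuming the TF.NET Saver API follows its Python namesake and that some hypothetical step 1500 was the last to improve valid_accuracy:

    // Minimal restore sketch, not from the diff: the graph must be rebuilt or
    // imported exactly as at save time, and "1500" is a hypothetical step number.
    using static Tensorflow.Binding;

    var graph = tf.Graph().as_default();
    // ... rebuild or import the word_cnn graph here ...
    using (var sess = tf.Session(graph))
    {
        var saver = tf.train.Saver();
        saver.restore(sess, "word_cnn/word_cnn.ckpt-1500");
    }
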
@@ -218,18 +217,10 @@ namespace TensorFlowNET.Examples
 
         public void PrepareData()
         {
-            if (UseSubset)
-            {
-                var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
-                Web.Download(url, dataDir, "dbpedia_subset.zip");
-                Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
-            }
-            else
-            {
-                string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
-                Web.Download(url, dataDir, dataFileName);
-                Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
-            }
+            // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
+            var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
+            Web.Download(url, dataDir, "dbpedia_subset.zip");
+            Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
 
             if (ImportGraph)
             {
@@ -242,7 +233,7 @@ namespace TensorFlowNET.Examples
                     Console.WriteLine("Discarding cached file: " + meta_path);
                     File.Delete(meta_path);
                 }
-                var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
+                url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                 Web.Download(url, "graph", meta_file);
             }
         }
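
Applied together, the two PrepareData hunks leave the method in roughly the following shape. Only the lines visible in the hunks are verbatim; the diff elides the lines between them, so the meta_file/meta_path setup below is a hypothetical fill-in for context:

    public void PrepareData()
    {
        // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
        var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
        Web.Download(url, dataDir, "dbpedia_subset.zip");
        Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));

        if (ImportGraph)
        {
            // hypothetical reconstruction of the elided lines between the two hunks
            var meta_file = "word_cnn.meta";
            var meta_path = Path.Combine("graph", meta_file);
            if (File.Exists(meta_path))  // the real staleness check is not shown in the diff
            {
                Console.WriteLine("Discarding cached file: " + meta_path);
                File.Delete(meta_path);
            }
            // reuses the url local declared above, hence "url =" rather than "var url ="
            url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
            Web.Download(url, "graph", meta_file);
        }
    }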