diff --git a/data/dbpedia_subset.zip b/data/dbpedia_subset.zip index e4ab6dda..120ac8a1 100644 Binary files a/data/dbpedia_subset.zip and b/data/dbpedia_subset.zip differ diff --git a/graph/word_cnn.meta b/graph/word_cnn.meta new file mode 100644 index 00000000..de19687b Binary files /dev/null and b/graph/word_cnn.meta differ diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index d4f71b74..266d5799 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -6,6 +6,12 @@ namespace Tensorflow { public static partial class tf { + public static VariableV1[] global_variables(string scope = null) + { + return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) + .ToArray(); + } + public static Operation global_variables_initializer() { var g = variables.global_variables(); diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index b4925f3a..9e3d66a6 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -14,7 +14,7 @@ namespace Tensorflow public static Optimizer AdamOptimizer(float learning_rate) => new AdamOptimizer(learning_rate); - public static Saver Saver() => new Saver(); + public static Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); diff --git a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs index 4c141ebf..8b5f79e2 100644 --- a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs +++ b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs @@ -12,6 +12,46 @@ namespace TensorFlowNET.Examples { public class DataHelpers { + public static Dictionary<string, int> build_word_dict(string path) + { + var contents = 
File.ReadAllLines(path); + + var words = new List<string>(); + foreach (var content in contents) + words.AddRange(clean_str(content).Split(' ').Where(x => x.Length > 1)); + var word_counter = words.GroupBy(x => x) + .Select(x => new { Word = x.Key, Count = x.Count() }) + .OrderByDescending(x => x.Count) + .ToArray(); + + var word_dict = new Dictionary<string, int>(); + word_dict["<pad>"] = 0; + word_dict["<unk>"] = 1; + word_dict["<eos>"] = 2; + foreach (var word in word_counter) + word_dict[word.Word] = word_dict.Count; + + return word_dict; + } + + public static (int[][], int[]) build_word_dataset(string path, Dictionary<string, int> word_dict, int document_max_len) + { + var contents = File.ReadAllLines(path); + var x = contents.Select(c => (clean_str(c) + " <eos>") + .Split(' ').Take(document_max_len) + .Select(w => word_dict.ContainsKey(w) ? word_dict[w] : word_dict["<unk>"]).ToArray()) + .ToArray(); + + for (int i = 0; i < x.Length; i++) + if (x[i].Length == document_max_len) + x[i][document_max_len - 1] = word_dict["<eos>"]; + else + Array.Resize(ref x[i], document_max_len); + + var y = contents.Select(c => int.Parse(c.Substring(0, c.IndexOf(','))) - 1).ToArray(); + + return (x, y); + } public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? 
limit = null, bool shuffle=true) { @@ -96,8 +136,8 @@ namespace TensorFlowNET.Examples private static string clean_str(string str) { - str = Regex.Replace(str, @"[^A-Za-z0-9(),!?\'\`]", " "); - str = Regex.Replace(str, @"\'s", " \'s"); + str = Regex.Replace(str, "[^A-Za-z0-9(),!?]", " "); + str = Regex.Replace(str, ",", " "); return str; } diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index 73c74e3f..c7652268 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -26,22 +26,23 @@ namespace TensorFlowNET.Examples.CnnTextClassification public string Name => "Text Classification"; public int? DataLimit = null; public bool ImportGraph { get; set; } = true; - public bool UseSubset = true; // <----- set this true to use a limited subset of dbpedia + public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia private string dataDir = "text_classification"; private string dataFileName = "dbpedia_csv.tar.gz"; - public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn + public string model_name = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv"; private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; - private const int CHAR_MAX_LEN = 1014; - private const int WORD_MAX_LEN = 1014; private const int NUM_CLASS = 14; private const int BATCH_SIZE = 64; private const int NUM_EPOCHS = 10; + private const int WORD_MAX_LEN = 100; + private const int CHAR_MAX_LEN = 1014; + protected float loss_value = 0; public bool Run() @@ -61,8 +62,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification { var stopwatch 
= Stopwatch.StartNew(); Console.WriteLine("Building dataset..."); - var path = UseSubset ? SUBSET_PATH : TRAIN_PATH; - var (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset); + var path = UseSubset ? SUBSET_PATH : TRAIN_PATH; + int[][] x = null; + int[] y = null; + int alphabet_size = 0; + int vocabulary_size = 0; + + if (model_name == "vd_cnn") + (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset); + else + { + var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); + vocabulary_size = len(word_dict); + (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); + } + Console.WriteLine("\tDONE "); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); @@ -75,18 +89,19 @@ namespace TensorFlowNET.Examples.CnnTextClassification Console.WriteLine("\tDONE " + stopwatch.Elapsed); sess.run(tf.global_variables_initializer()); + var saver = tf.train.Saver(tf.global_variables()); var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS); var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1; double max_accuracy = 0; - Tensor is_training = graph.get_tensor_by_name("is_training:0"); - Tensor model_x = graph.get_tensor_by_name("x:0"); - Tensor model_y = graph.get_tensor_by_name("y:0"); - Tensor loss = graph.get_tensor_by_name("loss/value:0"); - Tensor optimizer = graph.get_tensor_by_name("loss/optimizer:0"); - Tensor global_step = graph.get_tensor_by_name("global_step:0"); - Tensor accuracy = graph.get_tensor_by_name("accuracy/value:0"); + Tensor is_training = graph.OperationByName("is_training"); + Tensor model_x = graph.OperationByName("x"); + Tensor model_y = graph.OperationByName("y"); + Tensor loss = graph.OperationByName("loss/Mean"); // word_cnn + Operation optimizer = graph.OperationByName("loss/Adam"); // word_cnn + Tensor global_step = 
graph.OperationByName("Variable"); + Tensor accuracy = graph.OperationByName("accuracy/accuracy"); stopwatch = Stopwatch.StartNew(); int i = 0; foreach (var (x_batch, y_batch, total) in train_batches) @@ -105,11 +120,10 @@ namespace TensorFlowNET.Examples.CnnTextClassification var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); loss_value = result[2]; var step = (int)result[1]; - if (step % 10 == 0 || step < 10) + if (step % 10 == 0) { var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total); - Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}"); - Console.WriteLine($"Step {step} loss: {loss_value}"); + Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}"); } if (step % 100 == 0) @@ -133,13 +147,15 @@ namespace TensorFlowNET.Examples.CnnTextClassification var valid_accuracy = sum_accuracy / cnt; - print($"\nValidation Accuracy = {valid_accuracy}\n"); - - // # Save model - // if valid_accuracy > max_accuracy: - // max_accuracy = valid_accuracy - // saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step) - // print("Model is saved.\n") + print($"\nValidation Accuracy = {valid_accuracy}\n"); + + // # Save model + if (valid_accuracy > max_accuracy) + { + max_accuracy = valid_accuracy; + // saver.save(sess, $"{dataDir}/{model_name}.ckpt", global_step: step.ToString()); + print("Model is saved.\n"); + } } } @@ -180,9 +196,9 @@ namespace TensorFlowNET.Examples.CnnTextClassification //int samples = len / classes; int train_size = (int)Math.Round(len * (1 - test_size)); var train_x = x[new Slice(stop: train_size), new Slice()]; - var valid_x = x[new Slice(start: train_size + 1), new Slice()]; + var valid_x = x[new Slice(start: train_size), new Slice()]; var train_y = y[new Slice(stop: train_size)]; - var valid_y = y[new Slice(start: train_size + 1)]; + var valid_y = y[new Slice(start: 
train_size)]; Console.WriteLine("\tDONE"); return (train_x, valid_x, train_y, valid_y); }