diff --git a/graph/word_cnn.meta b/graph/word_cnn.meta
index de19687b..141947b1 100644
Binary files a/graph/word_cnn.meta and b/graph/word_cnn.meta differ
diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
index a6ced8fa..0f3a2ab8 100644
--- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text;
+using static Tensorflow.Python;
 
 namespace Tensorflow
 {
@@ -144,26 +145,20 @@ namespace Tensorflow
         public string save(Session sess,
             string save_path,
-            string global_step = "",
+            int global_step = -1,
             string latest_filename = "",
             string meta_graph_suffix = "meta",
             bool write_meta_graph = true,
             bool write_state = true,
-            bool strip_default_attrs = false)
+            bool strip_default_attrs = false,
+            bool save_debug_info = false)
         {
             if (string.IsNullOrEmpty(latest_filename))
                 latest_filename = "checkpoint";
 
             string model_checkpoint_path = "";
             string checkpoint_file = "";
-            if (!string.IsNullOrEmpty(global_step))
-            {
-
-            }
-            else
-            {
-                checkpoint_file = save_path;
-            }
+            checkpoint_file = $"{save_path}-{global_step}";
 
             var save_path_parent = Path.GetDirectoryName(save_path);
@@ -189,6 +184,7 @@ namespace Tensorflow
             if (write_meta_graph)
             {
                 string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix);
+                export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info);
             }
 
             return _is_empty ? string.Empty : model_checkpoint_path;
@@ -244,10 +240,11 @@ namespace Tensorflow
         public MetaGraphDef export_meta_graph(string filename= "",
             string[] collection_list = null,
             string export_scope = "",
-            bool as_text= false,
-            bool clear_devices= false,
-            bool clear_extraneous_savers= false,
-            bool strip_default_attrs= false)
+            bool as_text = false,
+            bool clear_devices = false,
+            bool clear_extraneous_savers = false,
+            bool strip_default_attrs = false,
+            bool save_debug_info = false)
         {
             return export_meta_graph(
                 filename: filename,
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 01b6d72c..aaac8e4a 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -26,14 +26,13 @@ namespace TensorFlowNET.Examples
         public string Name => "CNN Text Classification";
         public int? DataLimit = null;
         public bool ImportGraph { get; set; } = true;
-        public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia
 
-        private string dataDir = "text_classification";
+        private string dataDir = "word_cnn";
         private string dataFileName = "dbpedia_csv.tar.gz";
 
         private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
         private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
-        
+
         private const int NUM_CLASS = 14;
         private const int BATCH_SIZE = 64;
         private const int NUM_EPOCHS = 10;
@@ -41,6 +40,7 @@ namespace TensorFlowNET.Examples
         private const int CHAR_MAX_LEN = 1014;
 
         protected float loss_value = 0;
+        int vocabulary_size = 50000;
 
         public bool Run()
         {
@@ -63,10 +63,9 @@ namespace TensorFlowNET.Examples
             int[][] x = null;
             int[] y = null;
             int alphabet_size = 0;
-            int vocabulary_size = 0;
 
             var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
-            vocabulary_size = len(word_dict);
+            // vocabulary_size = len(word_dict);
             (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
             Console.WriteLine("\tDONE ");
 
@@ -142,7 +141,7 @@ namespace TensorFlowNET.Examples
                         if (valid_accuracy > max_accuracy)
                         {
                             max_accuracy = valid_accuracy;
-                            saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step.ToString());
+                            saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step);
                             print("Model is saved.\n");
                         }
                     }
@@ -218,18 +217,10 @@ namespace TensorFlowNET.Examples
 
         public void PrepareData()
         {
-            if (UseSubset)
-            {
-                var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
-                Web.Download(url, dataDir, "dbpedia_subset.zip");
-                Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
-            }
-            else
-            {
-                string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
-                Web.Download(url, dataDir, dataFileName);
-                Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
-            }
+            // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
+            var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
+            Web.Download(url, dataDir, "dbpedia_subset.zip");
+            Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
 
             if (ImportGraph)
             {
@@ -242,7 +233,7 @@ namespace TensorFlowNET.Examples
                         Console.WriteLine("Discarding cached file: " + meta_path);
                         File.Delete(meta_path);
                     }
-                    var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
+                    url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                     Web.Download(url, "graph", meta_file);
                 }
             }
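Usage note (not part of the diff): `Saver.save` now takes `global_step` as an `int` (default `-1`), always writes the checkpoint as `<save_path>-<global_step>`, and, when `write_meta_graph` is true, also exports the matching `.meta` file via `export_meta_graph`. Below is a minimal sketch of the new call site; the surrounding graph setup (`tf.Variable`, `tf.train.Saver()`, `with(tf.Session(), ...)`) follows the usual TensorFlow.NET conventions and is illustrative rather than taken from this diff.

```csharp
using Tensorflow;
using static Tensorflow.Python;

class SaverUsageSketch
{
    static void Main()
    {
        // Sketch only: a trivial graph with one variable so the Saver has something to save.
        var counter = tf.Variable(0f, name: "counter");
        var saver = tf.train.Saver();

        with(tf.Session(), sess =>
        {
            sess.run(tf.global_variables_initializer());

            int step = 100; // illustrative global step
            // With the new signature this writes "word_cnn/word_cnn.ckpt-100";
            // write_meta_graph defaults to true, so the matching
            // "word_cnn.ckpt-100.meta" is exported via export_meta_graph().
            saver.save(sess, "word_cnn/word_cnn.ckpt", global_step: step);
        });
    }
}
```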