Browse Source

word_cnn save training step works.

tags/v0.9
Oceania2018 6 years ago
parent
commit
c0fd13503d
3 changed files with 21 additions and 33 deletions
  1. BIN
      graph/word_cnn.meta
  2. +11
    -14
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  3. +10
    -19
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

BIN
graph/word_cnn.meta View File


+ 11
- 14
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using static Tensorflow.Python;

namespace Tensorflow
{
@@ -144,26 +145,20 @@ namespace Tensorflow

public string save(Session sess,
string save_path,
string global_step = "",
int global_step = -1,
string latest_filename = "",
string meta_graph_suffix = "meta",
bool write_meta_graph = true,
bool write_state = true,
bool strip_default_attrs = false)
bool strip_default_attrs = false,
bool save_debug_info = false)
{
if (string.IsNullOrEmpty(latest_filename))
latest_filename = "checkpoint";
string model_checkpoint_path = "";
string checkpoint_file = "";

if (!string.IsNullOrEmpty(global_step))
{

}
else
{
checkpoint_file = save_path;
}
checkpoint_file = $"{save_path}-{global_step}";

var save_path_parent = Path.GetDirectoryName(save_path);

@@ -189,6 +184,7 @@ namespace Tensorflow
if (write_meta_graph)
{
string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix);
export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info);
}

return _is_empty ? string.Empty : model_checkpoint_path;
@@ -244,10 +240,11 @@ namespace Tensorflow
public MetaGraphDef export_meta_graph(string filename= "",
string[] collection_list = null,
string export_scope = "",
bool as_text= false,
bool clear_devices= false,
bool clear_extraneous_savers= false,
bool strip_default_attrs= false)
bool as_text = false,
bool clear_devices = false,
bool clear_extraneous_savers = false,
bool strip_default_attrs = false,
bool save_debug_info = false)
{
return export_meta_graph(
filename: filename,


+ 10
- 19
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -26,14 +26,13 @@ namespace TensorFlowNET.Examples
public string Name => "CNN Text Classification";
public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia

private string dataDir = "text_classification";
private string dataDir = "word_cnn";
private string dataFileName = "dbpedia_csv.tar.gz";

private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
private const int NUM_CLASS = 14;
private const int BATCH_SIZE = 64;
private const int NUM_EPOCHS = 10;
@@ -41,6 +40,7 @@ namespace TensorFlowNET.Examples
private const int CHAR_MAX_LEN = 1014;
protected float loss_value = 0;
int vocabulary_size = 50000;

public bool Run()
{
@@ -63,10 +63,9 @@ namespace TensorFlowNET.Examples
int[][] x = null;
int[] y = null;
int alphabet_size = 0;
int vocabulary_size = 0;

var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
vocabulary_size = len(word_dict);
// vocabulary_size = len(word_dict);
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);

Console.WriteLine("\tDONE ");
@@ -142,7 +141,7 @@ namespace TensorFlowNET.Examples
if (valid_accuracy > max_accuracy)
{
max_accuracy = valid_accuracy;
saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step.ToString());
saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step);
print("Model is saved.\n");
}
}
@@ -218,18 +217,10 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
if (UseSubset)
{
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
}
else
{
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
}
// full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));

if (ImportGraph)
{
@@ -242,7 +233,7 @@ namespace TensorFlowNET.Examples
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
}
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);
}
}


Loading…
Cancel
Save