Browse Source

word_cnn training completed, but saving model doesn't work.

tags/v0.9
Oceania2018 6 years ago
parent
commit
e56e5d3292
6 changed files with 90 additions and 28 deletions
  1. BIN
      data/dbpedia_subset.zip
  2. BIN
      graph/word_cnn.meta
  3. +6
    -0
      src/TensorFlowNET.Core/APIs/tf.variable.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  5. +42
    -2
      test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
  6. +41
    -25
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

BIN
data/dbpedia_subset.zip View File


BIN
graph/word_cnn.meta View File


+ 6
- 0
src/TensorFlowNET.Core/APIs/tf.variable.cs View File

@@ -6,6 +6,12 @@ namespace Tensorflow
{ {
public static partial class tf public static partial class tf
{ {
/// <summary>
/// Returns the contents of the GLOBAL_VARIABLES graph collection as an array.
/// </summary>
/// <param name="scope">Scope value forwarded unchanged to <c>ops.get_collection</c>; defaults to null.</param>
/// <returns>
/// The collected variables, or an empty array when the lookup does not yield a
/// <c>List&lt;VariableV1&gt;</c> (null or a different runtime type).
/// </returns>
public static VariableV1[] global_variables(string scope = null)
{
    // `as` produces null instead of throwing when the collection is missing or has
    // another element type; guard so callers get an empty array rather than a
    // NullReferenceException from ToArray().
    var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>;
    if (collection == null)
        return new VariableV1[0];

    return collection.ToArray();
}

public static Operation global_variables_initializer() public static Operation global_variables_initializer()
{ {
var g = variables.global_variables(); var g = variables.global_variables();


+ 1
- 1
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow


public static Optimizer AdamOptimizer(float learning_rate) => new AdamOptimizer(learning_rate); public static Optimizer AdamOptimizer(float learning_rate) => new AdamOptimizer(learning_rate);


public static Saver Saver() => new Saver();
/// <summary>
/// Creates a new <c>Saver</c>, forwarding <paramref name="var_list"/> to its constructor.
/// </summary>
/// <param name="var_list">Variables handed to the saver's <c>var_list</c> argument; defaults to null.</param>
/// <returns>The newly constructed saver.</returns>
public static Saver Saver(VariableV1[] var_list = null)
{
    return new Saver(var_list: var_list);
}


public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
=> graph_io.write_graph(graph, logdir, name, as_text); => graph_io.write_graph(graph, logdir, name, as_text);


+ 42
- 2
test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs View File

@@ -12,6 +12,46 @@ namespace TensorFlowNET.Examples
{ {
public class DataHelpers public class DataHelpers
{ {
/// <summary>
/// Builds a word-to-id vocabulary from the text file at <paramref name="path"/>.
/// </summary>
/// <remarks>
/// Every line is passed through <c>clean_str</c> and split on single spaces; tokens of
/// length 0 or 1 are ignored. Ids 0, 1 and 2 are reserved for "&lt;pad&gt;", "&lt;unk&gt;"
/// and "&lt;eos&gt;"; the remaining words receive consecutive ids in descending order of
/// frequency.
/// </remarks>
/// <param name="path">File whose lines are read with <c>File.ReadAllLines</c>.</param>
/// <returns>Dictionary mapping each token to its integer id.</returns>
public static Dictionary<string, int> build_word_dict(string path)
{
    // Tokenise the whole file, keeping only tokens longer than one character.
    var tokens = File.ReadAllLines(path)
        .SelectMany(line => clean_str(line).Split(' '))
        .Where(token => token.Length > 1);

    // Distinct words, most frequent first. OrderByDescending is a stable sort, so
    // words with equal counts stay in first-seen order.
    var rankedWords = tokens
        .GroupBy(token => token)
        .OrderByDescending(group => group.Count())
        .Select(group => group.Key);

    var vocabulary = new Dictionary<string, int>
    {
        ["<pad>"] = 0,
        ["<unk>"] = 1,
        ["<eos>"] = 2
    };

    foreach (var word in rankedWords)
    {
        // The next free id is the current size of the dictionary.
        vocabulary[word] = vocabulary.Count;
    }

    return vocabulary;
}

/// <summary>
/// Converts each line of the file at <paramref name="path"/> into a fixed-length
/// sequence of word ids plus a zero-based class label.
/// </summary>
/// <remarks>
/// Each line is cleaned with <c>clean_str</c>, suffixed with "&lt;eos&gt;", split on spaces
/// and truncated to <paramref name="document_max_len"/> tokens. Empty tokens produced by
/// consecutive spaces are discarded so they are not encoded as "&lt;unk&gt;" and do not
/// consume sequence slots. Words missing from <paramref name="word_dict"/> map to the
/// "&lt;unk&gt;" id.
/// </remarks>
/// <param name="path">File whose lines are read with <c>File.ReadAllLines</c>.</param>
/// <param name="word_dict">Token-to-id map; must contain "&lt;unk&gt;" and "&lt;eos&gt;" when those ids are needed.</param>
/// <param name="document_max_len">Length of every returned id sequence.</param>
/// <returns>
/// A tuple of (id sequences, labels). The label is the integer before the first comma
/// of the line, minus one.
/// </returns>
public static (int[][], int[]) build_word_dataset(string path, Dictionary<string, int> word_dict, int document_max_len)
{
    var contents = File.ReadAllLines(path);

    var x = contents.Select(c => (clean_str(c) + " <eos>")
            // Drop empty tokens: clean_str replaces characters with spaces, so runs of
            // spaces are common and would otherwise each become an "<unk>" id.
            .Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries)
            .Take(document_max_len)
            // Single lookup instead of ContainsKey followed by the indexer.
            .Select(w => word_dict.TryGetValue(w, out var id) ? id : word_dict["<unk>"])
            .ToArray())
        .ToArray();

    for (int i = 0; i < x.Length; i++)
    {
        if (x[i].Length == document_max_len)
        {
            // A full-length sequence may have lost its terminator to truncation;
            // force the last slot to "<eos>".
            x[i][document_max_len - 1] = word_dict["<eos>"];
        }
        else
        {
            // Array.Resize zero-fills the new tail; 0 is the id build_word_dict
            // assigns to "<pad>".
            Array.Resize(ref x[i], document_max_len);
        }
    }

    // Label = leading integer field of the line, shifted to be zero-based.
    var y = contents.Select(c => int.Parse(c.Substring(0, c.IndexOf(','))) - 1).ToArray();
    return (x, y);
}


public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true) public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true)
{ {
@@ -96,8 +136,8 @@ namespace TensorFlowNET.Examples


private static string clean_str(string str) private static string clean_str(string str)
{ {
str = Regex.Replace(str, @"[^A-Za-z0-9(),!?\'\`]", " ");
str = Regex.Replace(str, @"\'s", " \'s");
str = Regex.Replace(str, "[^A-Za-z0-9(),!?]", " ");
str = Regex.Replace(str, ",", " ");
return str; return str;
} }




+ 41
- 25
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -26,22 +26,23 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public string Name => "Text Classification"; public string Name => "Text Classification";
public int? DataLimit = null; public int? DataLimit = null;
public bool ImportGraph { get; set; } = true; public bool ImportGraph { get; set; } = true;
public bool UseSubset = true; // <----- set this true to use a limited subset of dbpedia
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia


private string dataDir = "text_classification"; private string dataDir = "text_classification";
private string dataFileName = "dbpedia_csv.tar.gz"; private string dataFileName = "dbpedia_csv.tar.gz";


public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
public string model_name = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn


private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv"; private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";


private const int CHAR_MAX_LEN = 1014;
private const int WORD_MAX_LEN = 1014;
private const int NUM_CLASS = 14; private const int NUM_CLASS = 14;
private const int BATCH_SIZE = 64; private const int BATCH_SIZE = 64;
private const int NUM_EPOCHS = 10; private const int NUM_EPOCHS = 10;
private const int WORD_MAX_LEN = 100;
private const int CHAR_MAX_LEN = 1014;
protected float loss_value = 0; protected float loss_value = 0;


public bool Run() public bool Run()
@@ -61,8 +62,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification
{ {
var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
Console.WriteLine("Building dataset..."); Console.WriteLine("Building dataset...");
var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;
var (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset);
var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;
int[][] x = null;
int[] y = null;
int alphabet_size = 0;
int vocabulary_size = 0;

if (model_name == "vd_cnn")
(x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset);
else
{
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
vocabulary_size = len(word_dict);
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
}

Console.WriteLine("\tDONE "); Console.WriteLine("\tDONE ");


var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
@@ -75,18 +89,19 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Console.WriteLine("\tDONE " + stopwatch.Elapsed); Console.WriteLine("\tDONE " + stopwatch.Elapsed);


sess.run(tf.global_variables_initializer()); sess.run(tf.global_variables_initializer());
var saver = tf.train.Saver(tf.global_variables());
var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS); var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1; var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
double max_accuracy = 0; double max_accuracy = 0;


Tensor is_training = graph.get_tensor_by_name("is_training:0");
Tensor model_x = graph.get_tensor_by_name("x:0");
Tensor model_y = graph.get_tensor_by_name("y:0");
Tensor loss = graph.get_tensor_by_name("loss/value:0");
Tensor optimizer = graph.get_tensor_by_name("loss/optimizer:0");
Tensor global_step = graph.get_tensor_by_name("global_step:0");
Tensor accuracy = graph.get_tensor_by_name("accuracy/value:0");
Tensor is_training = graph.OperationByName("is_training");
Tensor model_x = graph.OperationByName("x");
Tensor model_y = graph.OperationByName("y");
Tensor loss = graph.OperationByName("loss/Mean"); // word_cnn
Operation optimizer = graph.OperationByName("loss/Adam"); // word_cnn
Tensor global_step = graph.OperationByName("Variable");
Tensor accuracy = graph.OperationByName("accuracy/accuracy");
stopwatch = Stopwatch.StartNew(); stopwatch = Stopwatch.StartNew();
int i = 0; int i = 0;
foreach (var (x_batch, y_batch, total) in train_batches) foreach (var (x_batch, y_batch, total) in train_batches)
@@ -105,11 +120,10 @@ namespace TensorFlowNET.Examples.CnnTextClassification
var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
loss_value = result[2]; loss_value = result[2];
var step = (int)result[1]; var step = (int)result[1];
if (step % 10 == 0 || step < 10)
if (step % 10 == 0)
{ {
var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total); var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
Console.WriteLine($"Step {step} loss: {loss_value}");
Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
} }


if (step % 100 == 0) if (step % 100 == 0)
@@ -133,13 +147,15 @@ namespace TensorFlowNET.Examples.CnnTextClassification


var valid_accuracy = sum_accuracy / cnt; var valid_accuracy = sum_accuracy / cnt;


print($"\nValidation Accuracy = {valid_accuracy}\n");

// # Save model
// if valid_accuracy > max_accuracy:
// max_accuracy = valid_accuracy
// saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step)
// print("Model is saved.\n")
print($"\nValidation Accuracy = {valid_accuracy}\n");
// # Save model
if (valid_accuracy > max_accuracy)
{
max_accuracy = valid_accuracy;
// saver.save(sess, $"{dataDir}/{model_name}.ckpt", global_step: step.ToString());
print("Model is saved.\n");
}
} }
} }


@@ -180,9 +196,9 @@ namespace TensorFlowNET.Examples.CnnTextClassification
//int samples = len / classes; //int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size)); int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop: train_size), new Slice()]; var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size + 1), new Slice()];
var valid_x = x[new Slice(start: train_size), new Slice()];
var train_y = y[new Slice(stop: train_size)]; var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size + 1)];
var valid_y = y[new Slice(start: train_size)];
Console.WriteLine("\tDONE"); Console.WriteLine("\tDONE");
return (train_x, valid_x, train_y, valid_y); return (train_x, valid_x, train_y, valid_y);
} }


Loading…
Cancel
Save