Browse Source

TextClassification example: use imported graph (in progress)

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
0cb07ece36
4 changed files with 88 additions and 15 deletions
  1. +5
    -3
      test/TensorFlowNET.Examples/Text/DataHelpers.cs
  2. +64
    -8
      test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs
  3. +14
    -0
      test/TensorFlowNET.Examples/Text/cnn_models/ITextClassificationModel.cs
  4. +5
    -4
      test/TensorFlowNET.Examples/Text/cnn_models/VdCnn.cs

+ 5
- 3
test/TensorFlowNET.Examples/Text/DataHelpers.cs View File

@@ -13,8 +13,10 @@ namespace TensorFlowNET.Examples.CnnTextClassification
private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit=null)
public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
{
if (model != "vd_cnn")
throw new NotImplementedException(model);
string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
/*if (step == "train")
df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/
@@ -41,8 +43,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
x[i][j] = char_dict["<pad>"];
else
x[i][j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict["<unk>"];
}
}
y[i] = int.Parse(parts[0]);
}



+ 64
- 8
test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs View File

@@ -4,6 +4,8 @@ using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.Keras.Engine;
using TensorFlowNET.Examples.Text.cnn_models;
using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;

@@ -15,7 +17,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public class TextClassificationTrain : Python, IExample
{
public int Priority => 100;
public bool Enabled { get; set; }= false;
public bool Enabled { get; set; } = false;
public string Name => "Text Classification";
public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
@@ -23,22 +25,68 @@ namespace TensorFlowNET.Examples.CnnTextClassification
private string dataDir = "text_classification";
private string dataFileName = "dbpedia_csv.tar.gz";

public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn

private const int CHAR_MAX_LEN = 1014;
private const int NUM_CLASS = 2;
protected float loss_value = 0;

public bool Run()
{
PrepareData();
return with(tf.Session(), sess =>
{
if (ImportGraph)
return RunWithImportedGraph(sess);
else
return RunWithBuiltGraph(sess);
});
}

protected virtual bool RunWithImportedGraph(Session sess)
{
var graph = tf.Graph().as_default();
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
var meta_file = model_name + "_untrained.meta";
tf.train.import_meta_graph(Path.Join("graph", meta_file));

//sess.run(tf.global_variables_initializer());

Tensor is_training = graph.get_operation_by_name("is_training");
Tensor model_x = graph.get_operation_by_name("x");
Tensor model_y = graph.get_operation_by_name("y");
//Tensor loss = graph.get_operation_by_name("loss");
//Tensor accuracy = graph.get_operation_by_name("accuracy");
return false;
}

protected virtual bool RunWithBuiltGraph(Session session)
{
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit);
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

return with(tf.Session(), sess =>
{
new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
return false;
});
ITextClassificationModel model = null;
switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
{
case "word_cnn":
case "char_cnn":
case "word_rnn":
case "att_rnn":
case "rcnn":
throw new NotImplementedException();
break;
case "vd_cnn":
model=new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
break;
}
// todo train the model
return false;
}

private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
@@ -53,7 +101,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
var train_y = new List<int>();
var valid_y = new List<int>();

for (int i = 0; i< classes; i++)
for (int i = 0; i < classes; i++)
{
for (int j = 0; j < samples; j++)
{
@@ -79,6 +127,14 @@ namespace TensorFlowNET.Examples.CnnTextClassification
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);

if (ImportGraph)
{
// download graph meta data
var meta_file = model_name + "_untrained.meta";
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);
}
}
}
}

+ 14
- 0
test/TensorFlowNET.Examples/Text/cnn_models/ITextClassificationModel.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
namespace TensorFlowNET.Examples.Text.cnn_models
{
interface ITextClassificationModel
{
Tensor is_training { get; }
Tensor x { get;}
Tensor y { get; }
}
}

+ 5
- 4
test/TensorFlowNET.Examples/Text/cnn_models/VdCnn.cs View File

@@ -3,10 +3,11 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Text.cnn_models;

namespace TensorFlowNET.Examples.TextClassification
{
public class VdCnn : Python
public class VdCnn : Python, ITextClassificationModel
{
private int embedding_size;
private int[] filter_sizes;
@@ -15,9 +16,9 @@ namespace TensorFlowNET.Examples.TextClassification
private float learning_rate;
private IInitializer cnn_initializer;
private IInitializer fc_initializer;
private Tensor x;
private Tensor y;
private Tensor is_training;
public Tensor x { get; private set; }
public Tensor y { get; private set; }
public Tensor is_training { get; private set; }
private RefVariable global_step;
private RefVariable embeddings;
private Tensor x_emb;


Loading…
Cancel
Save