@@ -4,6 +4,8 @@ using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.Keras.Engine;
using TensorFlowNET.Examples.Text.cnn_models;
using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
@@ -15,7 +17,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
    public class TextClassificationTrain : Python, IExample
    {
        public int Priority => 100;
        public bool Enabled { get; set; } = false;
        public string Name => "Text Classification";
        public int? DataLimit = null;
        public bool ImportGraph { get; set; } = true;
@@ -23,22 +25,68 @@ namespace TensorFlowNET.Examples.CnnTextClassification
        private string dataDir = "text_classification";
        private string dataFileName = "dbpedia_csv.tar.gz";

        public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn

        private const int CHAR_MAX_LEN = 1014;
        private const int NUM_CLASS = 2;
        protected float loss_value = 0;

        public bool Run()
        {
            PrepareData();
            return with(tf.Session(), sess =>
            {
                if (ImportGraph)
                    return RunWithImportedGraph(sess);
                else
                    return RunWithBuiltGraph(sess);
            });
        }

        protected virtual bool RunWithImportedGraph(Session sess)
        {
            var graph = tf.Graph().as_default();
            Console.WriteLine("Building dataset...");
            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

            var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            var meta_file = model_name + "_untrained.meta";
            tf.train.import_meta_graph(Path.Join("graph", meta_file));

            //sess.run(tf.global_variables_initializer());

            Tensor is_training = graph.get_operation_by_name("is_training");
            Tensor model_x = graph.get_operation_by_name("x");
            Tensor model_y = graph.get_operation_by_name("y");
            //Tensor loss = graph.get_operation_by_name("loss");
            //Tensor accuracy = graph.get_operation_by_name("accuracy");
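
            // Sketch only: how a training step might be issued once the imported graph also
            // exposes "loss", "accuracy" and an optimizer op. The op name "train_op" and the
            // batch variables below are assumptions, not part of this change, and the exact
            // sess.run/FeedItem overloads depend on the TensorFlow.NET version in use.
            //
            //   Tensor train_op = graph.get_operation_by_name("train_op");  // hypothetical op name
            //   sess.run(new object[] { train_op, loss },
            //            new FeedItem(model_x, batch_x),
            //            new FeedItem(model_y, batch_y),
            //            new FeedItem(is_training, true));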

            return false;
        }
        protected virtual bool RunWithBuiltGraph(Session session)
        {
            Console.WriteLine("Building dataset...");
            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

            var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            ITextClassificationModel model = null;
            switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
            {
                case "word_cnn":
                case "char_cnn":
                case "word_rnn":
                case "att_rnn":
                case "rcnn":
                    throw new NotImplementedException();
                case "vd_cnn":
                    model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
                    break;
            }

            // todo train the model
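            // Rough outline (assumption, not implemented in this change): batch train_x/train_y,
            // feed them through the selected model, minimize its loss with an optimizer, and
            // periodically evaluate accuracy on valid_x/valid_y before saving a checkpoint.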
            return false;
        }

        private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
@@ -53,7 +101,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
            var train_y = new List<int>();
            var valid_y = new List<int>();

            for (int i = 0; i < classes; i++)
            {
                for (int j = 0; j < samples; j++)
                {
@@ -79,6 +127,14 @@ namespace TensorFlowNET.Examples.CnnTextClassification
            string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
            Web.Download(url, dataDir, dataFileName);
            Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);

            if (ImportGraph)
            {
                // download graph meta data
                var meta_file = model_name + "_untrained.meta";
                url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                Web.Download(url, "graph", meta_file);
            }
        }
    }
}