diff --git a/test/TensorFlowNET.Examples/Text/DataHelpers.cs b/test/TensorFlowNET.Examples/Text/DataHelpers.cs index 43137b02..658a102a 100644 --- a/test/TensorFlowNET.Examples/Text/DataHelpers.cs +++ b/test/TensorFlowNET.Examples/Text/DataHelpers.cs @@ -13,8 +13,10 @@ namespace TensorFlowNET.Examples.CnnTextClassification private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; - public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit=null) + public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null) { + if (model != "vd_cnn") + throw new NotImplementedException(model); string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} "; /*if (step == "train") df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/ @@ -41,8 +43,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification x[i][j] = char_dict[""]; else x[i][j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict[""]; - } - + } + y[i] = int.Parse(parts[0]); } diff --git a/test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs index ce06cbc5..8931c5ba 100644 --- a/test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs @@ -4,6 +4,8 @@ using System.IO; using System.Linq; using System.Text; using Tensorflow; +using Tensorflow.Keras.Engine; +using TensorFlowNET.Examples.Text.cnn_models; using TensorFlowNET.Examples.TextClassification; using TensorFlowNET.Examples.Utility; @@ -15,7 +17,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification public class TextClassificationTrain : Python, IExample { public int Priority => 100; - public bool Enabled { get; set; }= false; + public bool Enabled { get; set; } = false; public string Name => "Text Classification"; public int? DataLimit = null; public bool ImportGraph { get; set; } = true; @@ -23,22 +25,68 @@ namespace TensorFlowNET.Examples.CnnTextClassification private string dataDir = "text_classification"; private string dataFileName = "dbpedia_csv.tar.gz"; + public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn + private const int CHAR_MAX_LEN = 1014; private const int NUM_CLASS = 2; + protected float loss_value = 0; public bool Run() { PrepareData(); + return with(tf.Session(), sess => + { + if (ImportGraph) + return RunWithImportedGraph(sess); + else + return RunWithBuiltGraph(sess); + }); + } + + protected virtual bool RunWithImportedGraph(Session sess) + { + var graph = tf.Graph().as_default(); + Console.WriteLine("Building dataset..."); + var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit); + + var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); + + var meta_file = model_name + "_untrained.meta"; + tf.train.import_meta_graph(Path.Join("graph", meta_file)); + + //sess.run(tf.global_variables_initializer()); + + Tensor is_training = graph.get_operation_by_name("is_training"); + Tensor model_x = graph.get_operation_by_name("x"); + Tensor model_y = graph.get_operation_by_name("y"); + //Tensor loss = graph.get_operation_by_name("loss"); + //Tensor accuracy = graph.get_operation_by_name("accuracy"); + return false; + } + + protected virtual bool RunWithBuiltGraph(Session session) + { Console.WriteLine("Building dataset..."); - var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit); + var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); - return with(tf.Session(), sess => - { - new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); - return false; - }); + ITextClassificationModel model = null; + switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn + { + case "word_cnn": + case "char_cnn": + case "word_rnn": + case "att_rnn": + case "rcnn": + throw new NotImplementedException(); + break; + case "vd_cnn": + model=new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); + break; + } + // todo train the model + return false; } private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f) @@ -53,7 +101,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification var train_y = new List(); var valid_y = new List(); - for (int i = 0; i< classes; i++) + for (int i = 0; i < classes; i++) { for (int j = 0; j < samples; j++) { @@ -79,6 +127,14 @@ namespace TensorFlowNET.Examples.CnnTextClassification string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; Web.Download(url, dataDir, dataFileName); Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); + + if (ImportGraph) + { + // download graph meta data + var meta_file = model_name + "_untrained.meta"; + url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; + Web.Download(url, "graph", meta_file); + } } } } diff --git a/test/TensorFlowNET.Examples/Text/cnn_models/ITextClassificationModel.cs b/test/TensorFlowNET.Examples/Text/cnn_models/ITextClassificationModel.cs new file mode 100644 index 00000000..e9778bba --- /dev/null +++ b/test/TensorFlowNET.Examples/Text/cnn_models/ITextClassificationModel.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Examples.Text.cnn_models +{ + interface ITextClassificationModel + { + Tensor is_training { get; } + Tensor x { get;} + Tensor y { get; } + } +} diff --git a/test/TensorFlowNET.Examples/Text/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/Text/cnn_models/VdCnn.cs index da6f49b6..ddcf16f1 100644 --- a/test/TensorFlowNET.Examples/Text/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/Text/cnn_models/VdCnn.cs @@ -3,10 +3,11 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow; +using TensorFlowNET.Examples.Text.cnn_models; namespace TensorFlowNET.Examples.TextClassification { - public class VdCnn : Python + public class VdCnn : Python, ITextClassificationModel { private int embedding_size; private int[] filter_sizes; @@ -15,9 +16,9 @@ namespace TensorFlowNET.Examples.TextClassification private float learning_rate; private IInitializer cnn_initializer; private IInitializer fc_initializer; - private Tensor x; - private Tensor y; - private Tensor is_training; + public Tensor x { get; private set; } + public Tensor y { get; private set; } + public Tensor is_training { get; private set; } private RefVariable global_step; private RefVariable embeddings; private Tensor x_emb;