diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 4babffb0..60011e58 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -103,7 +103,7 @@ namespace Tensorflow
}
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
- => _get_graph_from_inputs(op_input_list: op_input_list);
+ => _get_graph_from_inputs(op_input_list: op_input_list, graph: null);
public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null)
{
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
new file mode 100644
index 00000000..01b6d72c
--- /dev/null
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -0,0 +1,250 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using System.Text;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Sessions;
+using TensorFlowNET.Examples.Text.cnn_models;
+using TensorFlowNET.Examples.TextClassification;
+using TensorFlowNET.Examples.Utility;
+using static Tensorflow.Python;
+
+namespace TensorFlowNET.Examples
+{
+ ///
+ /// https://github.com/dongjun-Lee/text-classification-models-tf
+ ///
+ public class CnnTextClassification : IExample
+ {
+ public int Priority => 17;
+ public bool Enabled { get; set; } = true;
+ public string Name => "CNN Text Classification";
+ public int? DataLimit = null;
+ public bool ImportGraph { get; set; } = true;
+ public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia
+
+ private string dataDir = "text_classification";
+ private string dataFileName = "dbpedia_csv.tar.gz";
+
+ private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
+ private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
+
+ private const int NUM_CLASS = 14;
+ private const int BATCH_SIZE = 64;
+ private const int NUM_EPOCHS = 10;
+ private const int WORD_MAX_LEN = 100;
+ private const int CHAR_MAX_LEN = 1014;
+
+ protected float loss_value = 0;
+
+ public bool Run()
+ {
+ PrepareData();
+
+ var graph = tf.Graph().as_default();
+ return with(tf.Session(graph), sess =>
+ {
+ if (ImportGraph)
+ return RunWithImportedGraph(sess, graph);
+ else
+ return RunWithBuiltGraph(sess, graph);
+ });
+ }
+
+ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
+ {
+ var stopwatch = Stopwatch.StartNew();
+ Console.WriteLine("Building dataset...");
+ int[][] x = null;
+ int[] y = null;
+ int alphabet_size = 0;
+ int vocabulary_size = 0;
+
+ var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
+ vocabulary_size = len(word_dict);
+ (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
+
+ Console.WriteLine("\tDONE ");
+
+ var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
+ Console.WriteLine("Training set size: " + train_x.len);
+ Console.WriteLine("Test set size: " + valid_x.len);
+
+ Console.WriteLine("Import graph...");
+ var meta_file = "word_cnn.meta";
+ tf.train.import_meta_graph(Path.Join("graph", meta_file));
+ Console.WriteLine("\tDONE " + stopwatch.Elapsed);
+
+ sess.run(tf.global_variables_initializer());
+ var saver = tf.train.Saver(tf.global_variables());
+
+ var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
+ var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
+ double max_accuracy = 0;
+
+ Tensor is_training = graph.OperationByName("is_training");
+ Tensor model_x = graph.OperationByName("x");
+ Tensor model_y = graph.OperationByName("y");
+ Tensor loss = graph.OperationByName("loss/Mean");
+ Operation optimizer = graph.OperationByName("loss/Adam");
+ Tensor global_step = graph.OperationByName("Variable");
+ Tensor accuracy = graph.OperationByName("accuracy/accuracy");
+ stopwatch = Stopwatch.StartNew();
+ int i = 0;
+ foreach (var (x_batch, y_batch, total) in train_batches)
+ {
+ i++;
+ var train_feed_dict = new FeedDict
+ {
+ [model_x] = x_batch,
+ [model_y] = y_batch,
+ [is_training] = true,
+ };
+
+ var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
+ loss_value = result[2];
+ var step = (int)result[1];
+ if (step % 10 == 0)
+ {
+ var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
+ Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
+ }
+
+ if (step % 100 == 0)
+ {
+ // Test accuracy with validation data for each epoch.
+ var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
+ var (sum_accuracy, cnt) = (0.0f, 0);
+ foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
+ {
+ var valid_feed_dict = new FeedDict
+ {
+ [model_x] = valid_x_batch,
+ [model_y] = valid_y_batch,
+ [is_training] = false
+ };
+ var result1 = sess.run(accuracy, valid_feed_dict);
+ float accuracy_value = result1;
+ sum_accuracy += accuracy_value;
+ cnt += 1;
+ }
+
+ var valid_accuracy = sum_accuracy / cnt;
+
+ print($"\nValidation Accuracy = {valid_accuracy}\n");
+
+ // Save model
+ if (valid_accuracy > max_accuracy)
+ {
+ max_accuracy = valid_accuracy;
+ saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step.ToString());
+ print("Model is saved.\n");
+ }
+ }
+ }
+
+ return false;
+ }
+
+ protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
+ {
+ Console.WriteLine("Building dataset...");
+ var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "word_cnn", CHAR_MAX_LEN, DataLimit);
+
+ var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
+
+ ITextClassificationModel model = null;
+ // todo train the model
+ return false;
+ }
+
+ // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
+ private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
+ {
+ Console.WriteLine("Splitting in Training and Testing data...");
+ int len = x.shape[0];
+ //int classes = y.Data().Distinct().Count();
+ //int samples = len / classes;
+ int train_size = (int)Math.Round(len * (1 - test_size));
+ var train_x = x[new Slice(stop: train_size), new Slice()];
+ var valid_x = x[new Slice(start: train_size), new Slice()];
+ var train_y = y[new Slice(stop: train_size)];
+ var valid_y = y[new Slice(start: train_size)];
+ Console.WriteLine("\tDONE");
+ return (train_x, valid_x, train_y, valid_y);
+ }
+
+ private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary> labels)
+ {
+ int i = 0;
+ var label_keys = labels.Keys.ToArray();
+ while (i < shuffled_x.Length)
+ {
+ var key = label_keys[random.Next(label_keys.Length)];
+ var set = labels[key];
+ var index = set.First();
+ if (set.Count == 0)
+ {
+ labels.Remove(key); // remove the set as it is empty
+ label_keys = labels.Keys.ToArray();
+ }
+ shuffled_x[i] = x[index];
+ shuffled_y[i] = y[index];
+ i++;
+ }
+ }
+
+ private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
+ {
+ var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
+ var total_batches = num_batches_per_epoch * num_epochs;
+ foreach (var epoch in range(num_epochs))
+ {
+ foreach (var batch_num in range(num_batches_per_epoch))
+ {
+ var start_index = batch_num * batch_size;
+ var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
+ if (end_index <= start_index)
+ break;
+ yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
+ }
+ }
+ }
+
+ public void PrepareData()
+ {
+ if (UseSubset)
+ {
+ var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
+ Web.Download(url, dataDir, "dbpedia_subset.zip");
+ Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
+ }
+ else
+ {
+ string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
+ Web.Download(url, dataDir, dataFileName);
+ Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
+ }
+
+ if (ImportGraph)
+ {
+ // download graph meta data
+ var meta_file = "word_cnn.meta";
+ var meta_path = Path.Combine("graph", meta_file);
+ if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
+ {
+ // delete old cached file which contains errors
+ Console.WriteLine("Discarding cached file: " + meta_path);
+ File.Delete(meta_path);
+ }
+ var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
+ Web.Download(url, "graph", meta_file);
+ }
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
index 38a519d1..1c2237b2 100644
--- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
@@ -14,7 +14,7 @@ using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
-namespace TensorFlowNET.Examples.CnnTextClassification
+namespace TensorFlowNET.Examples
{
///
/// https://github.com/dongjun-Lee/text-classification-models-tf
diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
index bca6e64f..3fb3ec26 100644
--- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
+++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
@@ -4,7 +4,6 @@ using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using TensorFlowNET.Examples;
-using TensorFlowNET.Examples.CnnTextClassification;
namespace TensorFlowNET.ExamplesTests
{