From c42f8bbf88f9766fc485547399854503befee5d3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 30 May 2019 22:00:41 -0500 Subject: [PATCH] add train and predict interfaces to IExample. --- test/TensorFlowNET.Examples/BasicEagerApi.cs | 23 ++++++++++++-- .../BasicModels/KMeansClustering.cs | 23 ++++++++++++-- .../BasicModels/LinearRegression.cs | 24 +++++++++++++-- .../BasicModels/LogisticRegression.cs | 23 ++++++++++++-- .../BasicModels/NaiveBayesClassifier.cs | 25 ++++++++++++++-- .../BasicModels/NearestNeighbor.cs | 23 ++++++++++++-- .../BasicModels/NeuralNetXor.cs | 27 ++++++++++++++--- .../TensorFlowNET.Examples/BasicOperations.cs | 24 +++++++++++++-- test/TensorFlowNET.Examples/HelloWorld.cs | 24 +++++++++++++-- test/TensorFlowNET.Examples/IExample.cs | 21 ++++++++----- .../ImageProcess/ImageBackgroundRemoval.cs | 24 +++++++++++++-- .../ImageProcess/ImageRecognitionInception.cs | 23 ++++++++++++-- .../ImageProcess/InceptionArchGoogLeNet.cs | 23 ++++++++++++-- .../ImageProcess/ObjectDetection.cs | 23 ++++++++++++-- .../ImageProcess/RetrainImageClassifier.cs | 22 +++++++++++++- test/TensorFlowNET.Examples/Program.cs | 9 +++--- .../TextProcess/BinaryTextClassification.cs | 23 ++++++++++++-- .../TextProcess/CnnTextClassification.cs | 27 ++++++++++++++--- .../TextProcess/NER/BiLstmCrfNer.cs | 24 +++++++++++++-- .../TextProcess/NER/CRF.cs | 24 +++++++++++++-- .../TextProcess/NER/LstmCrfNer.cs | 24 +++++++++++++-- .../TextProcess/NamedEntityRecognition.cs | 25 ++++++++++++++-- .../TextProcess/TextClassificationTrain.cs | 29 ++++++++++++++---- .../TextProcess/Word2Vec.cs | 23 ++++++++++++-- .../ExamplesTests/ExamplesTest.cs | 30 +++++++++---------- 25 files changed, 503 insertions(+), 87 deletions(-) diff --git a/test/TensorFlowNET.Examples/BasicEagerApi.cs b/test/TensorFlowNET.Examples/BasicEagerApi.cs index 0cf572ef..35664310 100644 --- a/test/TensorFlowNET.Examples/BasicEagerApi.cs +++ b/test/TensorFlowNET.Examples/BasicEagerApi.cs @@ -11,10 +11,9 @@ namespace TensorFlowNET.Examples /// public class BasicEagerApi : IExample { - public int Priority => 100; public bool Enabled { get; set; } = false; public string Name => "Basic Eager"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; private Tensor a, b, c, d; @@ -46,5 +45,25 @@ namespace TensorFlowNET.Examples public void PrepareData() { } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs index 2b216439..b8f01865 100644 --- a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs @@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples /// public class KMeansClustering : IExample { - public int Priority => 8; public bool Enabled { get; set; } = true; public string Name => "K-means Clustering"; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; public int? train_size = null; public int validation_size = 5000; @@ -127,5 +126,25 @@ namespace TensorFlowNET.Examples string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; Web.Download(url, "graph", "kmeans.meta"); } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs index c5021da1..59a36530 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs @@ -13,11 +13,9 @@ namespace TensorFlowNET.Examples /// public class LinearRegression : IExample { - public int Priority => 3; public bool Enabled { get; set; } = true; public string Name => "Linear Regression"; - public bool ImportGraph { get; set; } = false; - + public bool IsImportingGraph { get; set; } = false; public int training_epochs = 1000; @@ -113,5 +111,25 @@ namespace TensorFlowNET.Examples 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); n_samples = train_X.shape[0]; } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 8e7779de..d8e9444a 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples /// public class LogisticRegression : IExample { - public int Priority => 4; public bool Enabled { get; set; } = true; public string Name => "Logistic Regression"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public int training_epochs = 10; @@ -158,5 +157,25 @@ namespace TensorFlowNET.Examples throw new ValueError("predict error, should be 90% accuracy"); }); } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + bool IExample.Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs index a9104b96..e1615035 100644 --- a/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs +++ b/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs @@ -13,10 +13,9 @@ namespace TensorFlowNET.Examples /// public class NaiveBayesClassifier : IExample { - public int Priority => 6; public bool Enabled { get; set; } = true; public string Name => "Naive Bayes Classifier"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public NDArray X, y; public Normal dist { get; set; } @@ -96,7 +95,7 @@ namespace TensorFlowNET.Examples this.dist = dist; } - public Tensor predict (NDArray X) + public Tensor predict(NDArray X) { if (dist == null) { @@ -170,5 +169,25 @@ namespace TensorFlowNET.Examples 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2); #endregion } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs index 0c99d84e..07838927 100644 --- a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs +++ b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs @@ -15,7 +15,6 @@ namespace TensorFlowNET.Examples /// public class NearestNeighbor : IExample { - public int Priority => 5; public bool Enabled { get; set; } = true; public string Name => "Nearest Neighbor"; Datasets mnist; @@ -23,7 +22,7 @@ namespace TensorFlowNET.Examples public int? TrainSize = null; public int ValidationSize = 5000; public int? TestSize = null; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public bool Run() @@ -76,5 +75,25 @@ namespace TensorFlowNET.Examples (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs b/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs index 9d62cfc2..6b8bf10f 100644 --- a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs +++ b/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs @@ -14,10 +14,9 @@ namespace TensorFlowNET.Examples /// public class NeuralNetXor : IExample { - public int Priority => 10; public bool Enabled { get; set; } = true; public string Name => "NN XOR"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public int num_steps = 10000; @@ -54,7 +53,7 @@ namespace TensorFlowNET.Examples { PrepareData(); float loss_value = 0; - if (ImportGraph) + if (IsImportingGraph) loss_value = RunWithImportedGraph(); else loss_value = RunWithBuiltGraph(); @@ -145,12 +144,32 @@ namespace TensorFlowNET.Examples {0, 1 } }; - if (ImportGraph) + if (IsImportingGraph) { // download graph meta data string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta"; Web.Download(url, "graph", "xor.meta"); } } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs index b2051cf9..ce861df7 100644 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples public class BasicOperations : IExample { public bool Enabled { get; set; } = true; - public int Priority => 2; public string Name => "Basic Operations"; - public bool ImportGraph { get; set; } = false; - + public bool IsImportingGraph { get; set; } = false; private Session sess; @@ -104,5 +102,25 @@ namespace TensorFlowNET.Examples public void PrepareData() { } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index f2d9681f..f963e3b3 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -12,11 +12,9 @@ namespace TensorFlowNET.Examples /// public class HelloWorld : IExample { - public int Priority => 1; public bool Enabled { get; set; } = true; public string Name => "Hello World"; - public bool ImportGraph { get; set; } = false; - + public bool IsImportingGraph { get; set; } = false; public bool Run() { @@ -41,5 +39,25 @@ namespace TensorFlowNET.Examples public void PrepareData() { } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/IExample.cs b/test/TensorFlowNET.Examples/IExample.cs index ed2521f6..3c9c24df 100644 --- a/test/TensorFlowNET.Examples/IExample.cs +++ b/test/TensorFlowNET.Examples/IExample.cs @@ -1,7 +1,8 @@ using System; using System.Collections.Generic; using System.Text; - +using Tensorflow; + namespace TensorFlowNET.Examples { /// @@ -10,11 +11,6 @@ namespace TensorFlowNET.Examples /// public interface IExample { - /// - /// running order - /// - int Priority { get; } - /// /// True to run example /// @@ -23,15 +19,24 @@ namespace TensorFlowNET.Examples /// /// Set true to import the computation graph instead of building it. /// - bool ImportGraph { get; set; } + bool IsImportingGraph { get; set; } string Name { get; } + bool Run(); + /// /// Build dataflow graph, train and predict /// /// - bool Run(); + bool Train(); + + bool Predict(); + + Graph ImportGraph(); + + Graph BuildGraph(); + /// /// Prepare dataset /// diff --git a/test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs b/test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs index 46556989..29e69a8d 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs @@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples.ImageProcess /// public class ImageBackgroundRemoval : IExample { - public int Priority => 15; - public bool Enabled { get; set; } = true; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; public string Name => "Image Background Removal"; @@ -59,5 +57,25 @@ namespace TensorFlowNET.Examples.ImageProcess Web.Download(url, modelDir, fileName); Compress.ExtractTGZ(Path.Join(modelDir, fileName), modelDir);*/ } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs b/test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs index 8fbe6631..b96bf494 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs @@ -20,10 +20,9 @@ namespace TensorFlowNET.Examples /// public class ImageRecognitionInception : IExample { - public int Priority => 7; public bool Enabled { get; set; } = true; public string Name => "Image Recognition Inception"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; string dir = "ImageRecognitionInception"; @@ -115,5 +114,25 @@ namespace TensorFlowNET.Examples file_ndarrays.Add(nd); } } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs b/test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs index e4d1cf2f..63a69425 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs @@ -21,9 +21,8 @@ namespace TensorFlowNET.Examples public class InceptionArchGoogLeNet : IExample { public bool Enabled { get; set; } = false; - public int Priority => 100; public string Name => "Inception Arch GoogLeNet"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; string dir = "label_image_data"; @@ -108,5 +107,25 @@ namespace TensorFlowNET.Examples url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}"; Utility.Web.Download(url, dir, pic); } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs b/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs index 27b232b2..485c5b30 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs @@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples public class ObjectDetection : IExample { - public int Priority => 11; public bool Enabled { get; set; } = true; public string Name => "Object Detection"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public float MIN_SCORE = 0.5f; @@ -145,5 +144,25 @@ namespace TensorFlowNET.Examples } } } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs index 2b2b5364..aaaf6865 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs @@ -26,7 +26,7 @@ namespace TensorFlowNET.Examples.ImageProcess public int Priority => 16; public bool Enabled { get; set; } = true; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; public string Name => "Retrain Image Classifier"; @@ -667,5 +667,25 @@ namespace TensorFlowNET.Examples.ImageProcess return result; } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 9fd9c714..49192704 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -19,7 +19,6 @@ namespace TensorFlowNET.Examples var examples = Assembly.GetEntryAssembly().GetTypes() .Where(x => x.GetInterfaces().Contains(typeof(IExample))) .Select(x => (IExample)Activator.CreateInstance(x)) - .OrderBy(x => x.Priority) .ToArray(); Console.WriteLine($"TensorFlow v{tf.VERSION}", Color.Yellow); @@ -42,18 +41,18 @@ namespace TensorFlowNET.Examples sw.Stop(); if (isSuccess) - success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s"); + success.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s"); else - errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s"); + errors.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s"); } else { - disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms"); + disabled.Add($"Example: {example.Name} in {sw.ElapsedMilliseconds}ms"); } } catch (Exception ex) { - errors.Add($"Example {example.Priority}: {example.Name}"); + errors.Add($"Example: {example.Name}"); Console.WriteLine(ex); } diff --git a/test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs index 63c3aefc..99ac6d94 100644 --- a/test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs @@ -17,10 +17,9 @@ namespace TensorFlowNET.Examples /// public class BinaryTextClassification : IExample { - public int Priority => 9; public bool Enabled { get; set; } = true; public string Name => "Binary Text Classification"; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; string dir = "binary_text_classification"; string dataFile = "imdb.zip"; @@ -138,5 +137,25 @@ namespace TensorFlowNET.Examples return result; } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index aaac8e4a..8bc2f9ef 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples /// public class CnnTextClassification : IExample { - public int Priority => 17; public bool Enabled { get; set; } = true; public string Name => "CNN Text Classification"; public int? DataLimit = null; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = false; private string dataDir = "word_cnn"; private string dataFileName = "dbpedia_csv.tar.gz"; @@ -49,7 +48,7 @@ namespace TensorFlowNET.Examples var graph = tf.Graph().as_default(); return with(tf.Session(graph), sess => { - if (ImportGraph) + if (IsImportingGraph) return RunWithImportedGraph(sess, graph); else return RunWithBuiltGraph(sess, graph); @@ -222,7 +221,7 @@ namespace TensorFlowNET.Examples Web.Download(url, dataDir, "dbpedia_subset.zip"); Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); - if (ImportGraph) + if (IsImportingGraph) { // download graph meta data var meta_file = "word_cnn.meta"; @@ -237,5 +236,25 @@ namespace TensorFlowNET.Examples Web.Download(url, "graph", meta_file); } } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs index 9f983fca..26626c1a 100644 --- a/test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs +++ b/test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs @@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples /// public class BiLstmCrfNer : IExample { - public int Priority => 101; - public bool Enabled { get; set; } = true; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public string Name => "bi-LSTM + CRF NER"; @@ -35,5 +33,25 @@ namespace TensorFlowNET.Examples hp.filepath_tags = Path.Combine(hp.data_root_dir, "vocab.tags.txt"); hp.filepath_glove = Path.Combine(hp.data_root_dir, "glove.npz"); } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs b/test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs index e238c4dc..1fbfddb6 100644 --- a/test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs +++ b/test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs @@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples /// public class CRF : IExample { - public int Priority => 13; - public bool Enabled { get; set; } = true; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; public string Name => "CRF"; @@ -31,5 +29,25 @@ namespace TensorFlowNET.Examples { } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs index 524bf6e6..1aa1e44a 100644 --- a/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs +++ b/test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs @@ -20,10 +20,8 @@ namespace TensorFlowNET.Examples.Text.NER /// public class LstmCrfNer : IExample { - public int Priority => 14; - public bool Enabled { get; set; } = true; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; public string Name => "LSTM + CRF NER"; @@ -208,5 +206,25 @@ namespace TensorFlowNET.Examples.Text.NER Web.Download(url, "graph", meta_file); } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs b/test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs index 7e229551..33bee661 100644 --- a/test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs +++ b/test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs @@ -11,13 +11,12 @@ namespace TensorFlowNET.Examples /// public class NamedEntityRecognition : IExample { - public int Priority => 100; public bool Enabled { get; set; } = false; public string Name => "NER"; - public bool ImportGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = false; - public bool Run() + public bool Train() { throw new NotImplementedException(); } @@ -26,5 +25,25 @@ namespace TensorFlowNET.Examples { throw new NotImplementedException(); } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Run() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index 1c2237b2..42841664 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples /// public class TextClassificationTrain : IExample { - public int Priority => 100; public bool Enabled { get; set; } = false; public string Name => "Text Classification"; public int? DataLimit = null; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia private string dataDir = "text_classification"; @@ -51,7 +50,7 @@ namespace TensorFlowNET.Examples var graph = tf.Graph().as_default(); return with(tf.Session(graph), sess => { - if (ImportGraph) + if (IsImportingGraph) return RunWithImportedGraph(sess, graph); else return RunWithBuiltGraph(sess, graph); @@ -255,7 +254,7 @@ namespace TensorFlowNET.Examples Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); } - if (ImportGraph) + if (IsImportingGraph) { // download graph meta data var meta_file = model_name + ".meta"; @@ -269,6 +268,26 @@ namespace TensorFlowNET.Examples var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; Web.Download(url, "graph", meta_file); } - } + } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs b/test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs index b4110cc6..dc3f27f3 100644 --- a/test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs +++ b/test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs @@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples /// public class Word2Vec : IExample { - public int Priority => 12; public bool Enabled { get; set; } = true; public string Name => "Word2Vec"; - public bool ImportGraph { get; set; } = true; + public bool IsImportingGraph { get; set; } = true; // Training Parameters float learning_rate = 0.1f; @@ -205,6 +204,26 @@ namespace TensorFlowNET.Examples print($"Most common words: {string.Join(", ", word2id.Take(10))}"); } + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public bool Train() + { + throw new NotImplementedException(); + } + + public bool Predict() + { + throw new NotImplementedException(); + } + private class WordId { public string Word { get; set; } diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index 3fb3ec26..c5497692 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests public void BasicOperations() { tf.Graph().as_default(); - new BasicOperations() { Enabled = true }.Run(); + new BasicOperations() { Enabled = true }.Train(); } [TestMethod] public void HelloWorld() { tf.Graph().as_default(); - new HelloWorld() { Enabled = true }.Run(); + new HelloWorld() { Enabled = true }.Train(); } [TestMethod] public void ImageRecognition() { tf.Graph().as_default(); - new HelloWorld() { Enabled = true }.Run(); + new HelloWorld() { Enabled = true }.Train(); } [Ignore] @@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests public void InceptionArchGoogLeNet() { tf.Graph().as_default(); - new InceptionArchGoogLeNet() { Enabled = true }.Run(); + new InceptionArchGoogLeNet() { Enabled = true }.Train(); } [TestMethod] public void KMeansClustering() { tf.Graph().as_default(); - new KMeansClustering() { Enabled = true, ImportGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); + new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train(); } [TestMethod] public void LinearRegression() { tf.Graph().as_default(); - new LinearRegression() { Enabled = true }.Run(); + new LinearRegression() { Enabled = true }.Train(); } [TestMethod] public void LogisticRegression() { tf.Graph().as_default(); - new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run(); + new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train(); } [Ignore] @@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests public void NaiveBayesClassifier() { tf.Graph().as_default(); - new NaiveBayesClassifier() { Enabled = false }.Run(); + new NaiveBayesClassifier() { Enabled = false }.Train(); } [Ignore] @@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests public void NamedEntityRecognition() { tf.Graph().as_default(); - new NamedEntityRecognition() { Enabled = true }.Run(); + new NamedEntityRecognition() { Enabled = true }.Train(); } [TestMethod] public void NearestNeighbor() { tf.Graph().as_default(); - new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); + new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train(); } [Ignore] @@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests public void TextClassificationTrain() { tf.Graph().as_default(); - new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); + new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train(); } [Ignore] @@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests public void TextClassificationWithMovieReviews() { tf.Graph().as_default(); - new BinaryTextClassification() { Enabled = true }.Run(); + new BinaryTextClassification() { Enabled = true }.Train(); } [TestMethod] public void NeuralNetXor() { tf.Graph().as_default(); - Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = false }.Run()); + Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train()); } [TestMethod] public void NeuralNetXor_ImportedGraph() { tf.Graph().as_default(); - Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = true }.Run()); + Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train()); } @@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests public void ObjectDetection() { tf.Graph().as_default(); - Assert.IsTrue(new ObjectDetection() { Enabled = true, ImportGraph = true }.Run()); + Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train()); } } }