@@ -11,10 +11,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class BasicEagerApi : IExample | |||
{ | |||
public int Priority => 100; | |||
public bool Enabled { get; set; } = false; | |||
public string Name => "Basic Eager"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
private Tensor a, b, c, d; | |||
@@ -46,5 +45,25 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class KMeansClustering : IExample | |||
{ | |||
public int Priority => 8; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "K-means Clustering"; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
public int? train_size = null; | |||
public int validation_size = 5000; | |||
@@ -127,5 +126,25 @@ namespace TensorFlowNET.Examples | |||
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; | |||
Web.Download(url, "graph", "kmeans.meta"); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -13,11 +13,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class LinearRegression : IExample | |||
{ | |||
public int Priority => 3; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Linear Regression"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public int training_epochs = 1000; | |||
@@ -113,5 +111,25 @@ namespace TensorFlowNET.Examples | |||
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); | |||
n_samples = train_X.shape[0]; | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class LogisticRegression : IExample | |||
{ | |||
public int Priority => 4; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Logistic Regression"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public int training_epochs = 10; | |||
@@ -158,5 +157,25 @@ namespace TensorFlowNET.Examples | |||
throw new ValueError("predict error, should be 90% accuracy"); | |||
}); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
bool IExample.Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -13,10 +13,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class NaiveBayesClassifier : IExample | |||
{ | |||
public int Priority => 6; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Naive Bayes Classifier"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public NDArray X, y; | |||
public Normal dist { get; set; } | |||
@@ -96,7 +95,7 @@ namespace TensorFlowNET.Examples | |||
this.dist = dist; | |||
} | |||
public Tensor predict (NDArray X) | |||
public Tensor predict(NDArray X) | |||
{ | |||
if (dist == null) | |||
{ | |||
@@ -170,5 +169,25 @@ namespace TensorFlowNET.Examples | |||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2); | |||
#endregion | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -15,7 +15,6 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class NearestNeighbor : IExample | |||
{ | |||
public int Priority => 5; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Nearest Neighbor"; | |||
Datasets mnist; | |||
@@ -23,7 +22,7 @@ namespace TensorFlowNET.Examples | |||
public int? TrainSize = null; | |||
public int ValidationSize = 5000; | |||
public int? TestSize = null; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public bool Run() | |||
@@ -76,5 +75,25 @@ namespace TensorFlowNET.Examples | |||
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | |||
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -14,10 +14,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class NeuralNetXor : IExample | |||
{ | |||
public int Priority => 10; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "NN XOR"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public int num_steps = 10000; | |||
@@ -54,7 +53,7 @@ namespace TensorFlowNET.Examples | |||
{ | |||
PrepareData(); | |||
float loss_value = 0; | |||
if (ImportGraph) | |||
if (IsImportingGraph) | |||
loss_value = RunWithImportedGraph(); | |||
else | |||
loss_value = RunWithBuiltGraph(); | |||
@@ -145,12 +144,32 @@ namespace TensorFlowNET.Examples | |||
{0, 1 } | |||
}; | |||
if (ImportGraph) | |||
if (IsImportingGraph) | |||
{ | |||
// download graph meta data | |||
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta"; | |||
Web.Download(url, "graph", "xor.meta"); | |||
} | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples | |||
public class BasicOperations : IExample | |||
{ | |||
public bool Enabled { get; set; } = true; | |||
public int Priority => 2; | |||
public string Name => "Basic Operations"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
private Session sess; | |||
@@ -104,5 +102,25 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -12,11 +12,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class HelloWorld : IExample | |||
{ | |||
public int Priority => 1; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Hello World"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public bool Run() | |||
{ | |||
@@ -41,5 +39,25 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -1,7 +1,8 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow; | |||
namespace TensorFlowNET.Examples | |||
{ | |||
/// <summary> | |||
@@ -10,11 +11,6 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public interface IExample | |||
{ | |||
/// <summary> | |||
/// running order | |||
/// </summary> | |||
int Priority { get; } | |||
/// <summary> | |||
/// True to run example | |||
/// </summary> | |||
@@ -23,15 +19,24 @@ namespace TensorFlowNET.Examples | |||
/// <summary> | |||
/// Set true to import the computation graph instead of building it. | |||
/// </summary> | |||
bool ImportGraph { get; set; } | |||
bool IsImportingGraph { get; set; } | |||
string Name { get; } | |||
bool Run(); | |||
/// <summary> | |||
/// Build dataflow graph, train and predict | |||
/// </summary> | |||
/// <returns></returns> | |||
bool Run(); | |||
bool Train(); | |||
bool Predict(); | |||
Graph ImportGraph(); | |||
Graph BuildGraph(); | |||
/// <summary> | |||
/// Prepare dataset | |||
/// </summary> | |||
@@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
/// </summary> | |||
public class ImageBackgroundRemoval : IExample | |||
{ | |||
public int Priority => 15; | |||
public bool Enabled { get; set; } = true; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
public string Name => "Image Background Removal"; | |||
@@ -59,5 +57,25 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
Web.Download(url, modelDir, fileName); | |||
Compress.ExtractTGZ(Path.Join(modelDir, fileName), modelDir);*/ | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -20,10 +20,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class ImageRecognitionInception : IExample | |||
{ | |||
public int Priority => 7; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Image Recognition Inception"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
string dir = "ImageRecognitionInception"; | |||
@@ -115,5 +114,25 @@ namespace TensorFlowNET.Examples | |||
file_ndarrays.Add(nd); | |||
} | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -21,9 +21,8 @@ namespace TensorFlowNET.Examples | |||
public class InceptionArchGoogLeNet : IExample | |||
{ | |||
public bool Enabled { get; set; } = false; | |||
public int Priority => 100; | |||
public string Name => "Inception Arch GoogLeNet"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
string dir = "label_image_data"; | |||
@@ -108,5 +107,25 @@ namespace TensorFlowNET.Examples | |||
url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}"; | |||
Utility.Web.Download(url, dir, pic); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples | |||
public class ObjectDetection : IExample | |||
{ | |||
public int Priority => 11; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Object Detection"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public float MIN_SCORE = 0.5f; | |||
@@ -145,5 +144,25 @@ namespace TensorFlowNET.Examples | |||
} | |||
} | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -26,7 +26,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
public int Priority => 16; | |||
public bool Enabled { get; set; } = true; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
public string Name => "Retrain Image Classifier"; | |||
@@ -667,5 +667,25 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
return result; | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -19,7 +19,6 @@ namespace TensorFlowNET.Examples | |||
var examples = Assembly.GetEntryAssembly().GetTypes() | |||
.Where(x => x.GetInterfaces().Contains(typeof(IExample))) | |||
.Select(x => (IExample)Activator.CreateInstance(x)) | |||
.OrderBy(x => x.Priority) | |||
.ToArray(); | |||
Console.WriteLine($"TensorFlow v{tf.VERSION}", Color.Yellow); | |||
@@ -42,18 +41,18 @@ namespace TensorFlowNET.Examples | |||
sw.Stop(); | |||
if (isSuccess) | |||
success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s"); | |||
success.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s"); | |||
else | |||
errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s"); | |||
errors.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s"); | |||
} | |||
else | |||
{ | |||
disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms"); | |||
disabled.Add($"Example: {example.Name} in {sw.ElapsedMilliseconds}ms"); | |||
} | |||
} | |||
catch (Exception ex) | |||
{ | |||
errors.Add($"Example {example.Priority}: {example.Name}"); | |||
errors.Add($"Example: {example.Name}"); | |||
Console.WriteLine(ex); | |||
} | |||
@@ -17,10 +17,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class BinaryTextClassification : IExample | |||
{ | |||
public int Priority => 9; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Binary Text Classification"; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
string dir = "binary_text_classification"; | |||
string dataFile = "imdb.zip"; | |||
@@ -138,5 +137,25 @@ namespace TensorFlowNET.Examples | |||
return result; | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class CnnTextClassification : IExample | |||
{ | |||
public int Priority => 17; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "CNN Text Classification"; | |||
public int? DataLimit = null; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = false; | |||
private string dataDir = "word_cnn"; | |||
private string dataFileName = "dbpedia_csv.tar.gz"; | |||
@@ -49,7 +48,7 @@ namespace TensorFlowNET.Examples | |||
var graph = tf.Graph().as_default(); | |||
return with(tf.Session(graph), sess => | |||
{ | |||
if (ImportGraph) | |||
if (IsImportingGraph) | |||
return RunWithImportedGraph(sess, graph); | |||
else | |||
return RunWithBuiltGraph(sess, graph); | |||
@@ -222,7 +221,7 @@ namespace TensorFlowNET.Examples | |||
Web.Download(url, dataDir, "dbpedia_subset.zip"); | |||
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); | |||
if (ImportGraph) | |||
if (IsImportingGraph) | |||
{ | |||
// download graph meta data | |||
var meta_file = "word_cnn.meta"; | |||
@@ -237,5 +236,25 @@ namespace TensorFlowNET.Examples | |||
Web.Download(url, "graph", meta_file); | |||
} | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class BiLstmCrfNer : IExample | |||
{ | |||
public int Priority => 101; | |||
public bool Enabled { get; set; } = true; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public string Name => "bi-LSTM + CRF NER"; | |||
@@ -35,5 +33,25 @@ namespace TensorFlowNET.Examples | |||
hp.filepath_tags = Path.Combine(hp.data_root_dir, "vocab.tags.txt"); | |||
hp.filepath_glove = Path.Combine(hp.data_root_dir, "glove.npz"); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class CRF : IExample | |||
{ | |||
public int Priority => 13; | |||
public bool Enabled { get; set; } = true; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public string Name => "CRF"; | |||
@@ -31,5 +29,25 @@ namespace TensorFlowNET.Examples | |||
{ | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -20,10 +20,8 @@ namespace TensorFlowNET.Examples.Text.NER | |||
/// </summary> | |||
public class LstmCrfNer : IExample | |||
{ | |||
public int Priority => 14; | |||
public bool Enabled { get; set; } = true; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
public string Name => "LSTM + CRF NER"; | |||
@@ -208,5 +206,25 @@ namespace TensorFlowNET.Examples.Text.NER | |||
Web.Download(url, "graph", meta_file); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -11,13 +11,12 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class NamedEntityRecognition : IExample | |||
{ | |||
public int Priority => 100; | |||
public bool Enabled { get; set; } = false; | |||
public string Name => "NER"; | |||
public bool ImportGraph { get; set; } = false; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public bool Run() | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -26,5 +25,25 @@ namespace TensorFlowNET.Examples | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Run() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class TextClassificationTrain : IExample | |||
{ | |||
public int Priority => 100; | |||
public bool Enabled { get; set; } = false; | |||
public string Name => "Text Classification"; | |||
public int? DataLimit = null; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia | |||
private string dataDir = "text_classification"; | |||
@@ -51,7 +50,7 @@ namespace TensorFlowNET.Examples | |||
var graph = tf.Graph().as_default(); | |||
return with(tf.Session(graph), sess => | |||
{ | |||
if (ImportGraph) | |||
if (IsImportingGraph) | |||
return RunWithImportedGraph(sess, graph); | |||
else | |||
return RunWithBuiltGraph(sess, graph); | |||
@@ -255,7 +254,7 @@ namespace TensorFlowNET.Examples | |||
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); | |||
} | |||
if (ImportGraph) | |||
if (IsImportingGraph) | |||
{ | |||
// download graph meta data | |||
var meta_file = model_name + ".meta"; | |||
@@ -269,6 +268,26 @@ namespace TensorFlowNET.Examples | |||
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; | |||
Web.Download(url, "graph", meta_file); | |||
} | |||
} | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples | |||
/// </summary> | |||
public class Word2Vec : IExample | |||
{ | |||
public int Priority => 12; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "Word2Vec"; | |||
public bool ImportGraph { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = true; | |||
// Training Parameters | |||
float learning_rate = 0.1f; | |||
@@ -205,6 +204,26 @@ namespace TensorFlowNET.Examples | |||
print($"Most common words: {string.Join(", ", word2id.Take(10))}"); | |||
} | |||
public Graph ImportGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Graph BuildGraph() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Train() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public bool Predict() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
private class WordId | |||
{ | |||
public string Word { get; set; } | |||
@@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests | |||
public void BasicOperations() | |||
{ | |||
tf.Graph().as_default(); | |||
new BasicOperations() { Enabled = true }.Run(); | |||
new BasicOperations() { Enabled = true }.Train(); | |||
} | |||
[TestMethod] | |||
public void HelloWorld() | |||
{ | |||
tf.Graph().as_default(); | |||
new HelloWorld() { Enabled = true }.Run(); | |||
new HelloWorld() { Enabled = true }.Train(); | |||
} | |||
[TestMethod] | |||
public void ImageRecognition() | |||
{ | |||
tf.Graph().as_default(); | |||
new HelloWorld() { Enabled = true }.Run(); | |||
new HelloWorld() { Enabled = true }.Train(); | |||
} | |||
[Ignore] | |||
@@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests | |||
public void InceptionArchGoogLeNet() | |||
{ | |||
tf.Graph().as_default(); | |||
new InceptionArchGoogLeNet() { Enabled = true }.Run(); | |||
new InceptionArchGoogLeNet() { Enabled = true }.Train(); | |||
} | |||
[TestMethod] | |||
public void KMeansClustering() | |||
{ | |||
tf.Graph().as_default(); | |||
new KMeansClustering() { Enabled = true, ImportGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); | |||
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train(); | |||
} | |||
[TestMethod] | |||
public void LinearRegression() | |||
{ | |||
tf.Graph().as_default(); | |||
new LinearRegression() { Enabled = true }.Run(); | |||
new LinearRegression() { Enabled = true }.Train(); | |||
} | |||
[TestMethod] | |||
public void LogisticRegression() | |||
{ | |||
tf.Graph().as_default(); | |||
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run(); | |||
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train(); | |||
} | |||
[Ignore] | |||
@@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests | |||
public void NaiveBayesClassifier() | |||
{ | |||
tf.Graph().as_default(); | |||
new NaiveBayesClassifier() { Enabled = false }.Run(); | |||
new NaiveBayesClassifier() { Enabled = false }.Train(); | |||
} | |||
[Ignore] | |||
@@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests | |||
public void NamedEntityRecognition() | |||
{ | |||
tf.Graph().as_default(); | |||
new NamedEntityRecognition() { Enabled = true }.Run(); | |||
new NamedEntityRecognition() { Enabled = true }.Train(); | |||
} | |||
[TestMethod] | |||
public void NearestNeighbor() | |||
{ | |||
tf.Graph().as_default(); | |||
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); | |||
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train(); | |||
} | |||
[Ignore] | |||
@@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests | |||
public void TextClassificationTrain() | |||
{ | |||
tf.Graph().as_default(); | |||
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); | |||
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train(); | |||
} | |||
[Ignore] | |||
@@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests | |||
public void TextClassificationWithMovieReviews() | |||
{ | |||
tf.Graph().as_default(); | |||
new BinaryTextClassification() { Enabled = true }.Run(); | |||
new BinaryTextClassification() { Enabled = true }.Train(); | |||
} | |||
[TestMethod] | |||
public void NeuralNetXor() | |||
{ | |||
tf.Graph().as_default(); | |||
Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = false }.Run()); | |||
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train()); | |||
} | |||
[TestMethod] | |||
public void NeuralNetXor_ImportedGraph() | |||
{ | |||
tf.Graph().as_default(); | |||
Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = true }.Run()); | |||
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train()); | |||
} | |||
@@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests | |||
public void ObjectDetection() | |||
{ | |||
tf.Graph().as_default(); | |||
Assert.IsTrue(new ObjectDetection() { Enabled = true, ImportGraph = true }.Run()); | |||
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train()); | |||
} | |||
} | |||
} |