Browse Source

add train and predict interfaces to IExample.

tags/v0.9
Oceania2018 6 years ago
parent
commit
c42f8bbf88
25 changed files with 503 additions and 87 deletions
  1. +21
    -2
      test/TensorFlowNET.Examples/BasicEagerApi.cs
  2. +21
    -2
      test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs
  3. +21
    -3
      test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
  4. +21
    -2
      test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
  5. +22
    -3
      test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs
  6. +21
    -2
      test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
  7. +23
    -4
      test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs
  8. +21
    -3
      test/TensorFlowNET.Examples/BasicOperations.cs
  9. +21
    -3
      test/TensorFlowNET.Examples/HelloWorld.cs
  10. +13
    -8
      test/TensorFlowNET.Examples/IExample.cs
  11. +21
    -3
      test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs
  12. +21
    -2
      test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs
  13. +21
    -2
      test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs
  14. +21
    -2
      test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs
  15. +21
    -1
      test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
  16. +4
    -5
      test/TensorFlowNET.Examples/Program.cs
  17. +21
    -2
      test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs
  18. +23
    -4
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
  19. +21
    -3
      test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs
  20. +21
    -3
      test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs
  21. +21
    -3
      test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
  22. +22
    -3
      test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs
  23. +24
    -5
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
  24. +21
    -2
      test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs
  25. +15
    -15
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 21
- 2
test/TensorFlowNET.Examples/BasicEagerApi.cs View File

@@ -11,10 +11,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class BasicEagerApi : IExample
{
public int Priority => 100;
public bool Enabled { get; set; } = false;
public string Name => "Basic Eager";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;

private Tensor a, b, c, d;

@@ -46,5 +45,25 @@ namespace TensorFlowNET.Examples
public void PrepareData()
{
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs View File

@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class KMeansClustering : IExample
{
public int Priority => 8;
public bool Enabled { get; set; } = true;
public string Name => "K-means Clustering";
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;

public int? train_size = null;
public int validation_size = 5000;
@@ -127,5 +126,25 @@ namespace TensorFlowNET.Examples
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
Web.Download(url, "graph", "kmeans.meta");
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 3
test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs View File

@@ -13,11 +13,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class LinearRegression : IExample
{
public int Priority => 3;
public bool Enabled { get; set; } = true;
public string Name => "Linear Regression";
public bool ImportGraph { get; set; } = false;

public bool IsImportingGraph { get; set; } = false;

public int training_epochs = 1000;

@@ -113,5 +111,25 @@ namespace TensorFlowNET.Examples
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
n_samples = train_X.shape[0];
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs View File

@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class LogisticRegression : IExample
{
public int Priority => 4;
public bool Enabled { get; set; } = true;
public string Name => "Logistic Regression";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public int training_epochs = 10;
@@ -158,5 +157,25 @@ namespace TensorFlowNET.Examples
throw new ValueError("predict error, should be 90% accuracy");
});
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

bool IExample.Predict()
{
throw new NotImplementedException();
}
}
}

+ 22
- 3
test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs View File

@@ -13,10 +13,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NaiveBayesClassifier : IExample
{
public int Priority => 6;
public bool Enabled { get; set; } = true;
public string Name => "Naive Bayes Classifier";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;

public NDArray X, y;
public Normal dist { get; set; }
@@ -96,7 +95,7 @@ namespace TensorFlowNET.Examples
this.dist = dist;
}

public Tensor predict (NDArray X)
public Tensor predict(NDArray X)
{
if (dist == null)
{
@@ -170,5 +169,25 @@ namespace TensorFlowNET.Examples
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);
#endregion
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs View File

@@ -15,7 +15,6 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NearestNeighbor : IExample
{
public int Priority => 5;
public bool Enabled { get; set; } = true;
public string Name => "Nearest Neighbor";
Datasets mnist;
@@ -23,7 +22,7 @@ namespace TensorFlowNET.Examples
public int? TrainSize = null;
public int ValidationSize = 5000;
public int? TestSize = null;
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public bool Run()
@@ -76,5 +75,25 @@ namespace TensorFlowNET.Examples
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 23
- 4
test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs View File

@@ -14,10 +14,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NeuralNetXor : IExample
{
public int Priority => 10;
public bool Enabled { get; set; } = true;
public string Name => "NN XOR";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;

public int num_steps = 10000;

@@ -54,7 +53,7 @@ namespace TensorFlowNET.Examples
{
PrepareData();
float loss_value = 0;
if (ImportGraph)
if (IsImportingGraph)
loss_value = RunWithImportedGraph();
else
loss_value = RunWithBuiltGraph();
@@ -145,12 +144,32 @@ namespace TensorFlowNET.Examples
{0, 1 }
};

if (ImportGraph)
if (IsImportingGraph)
{
// download graph meta data
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta";
Web.Download(url, "graph", "xor.meta");
}
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 3
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples
public class BasicOperations : IExample
{
public bool Enabled { get; set; } = true;
public int Priority => 2;
public string Name => "Basic Operations";
public bool ImportGraph { get; set; } = false;

public bool IsImportingGraph { get; set; } = false;

private Session sess;

@@ -104,5 +102,25 @@ namespace TensorFlowNET.Examples
public void PrepareData()
{
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 3
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -12,11 +12,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class HelloWorld : IExample
{
public int Priority => 1;
public bool Enabled { get; set; } = true;
public string Name => "Hello World";
public bool ImportGraph { get; set; } = false;

public bool IsImportingGraph { get; set; } = false;

public bool Run()
{
@@ -41,5 +39,25 @@ namespace TensorFlowNET.Examples
public void PrepareData()
{
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 13
- 8
test/TensorFlowNET.Examples/IExample.cs View File

@@ -1,7 +1,8 @@
using System;
using System.Collections.Generic;
using System.Text;

using Tensorflow;
namespace TensorFlowNET.Examples
{
/// <summary>
@@ -10,11 +11,6 @@ namespace TensorFlowNET.Examples
/// </summary>
public interface IExample
{
/// <summary>
/// running order
/// </summary>
int Priority { get; }

/// <summary>
/// True to run example
/// </summary>
@@ -23,15 +19,24 @@ namespace TensorFlowNET.Examples
/// <summary>
/// Set true to import the computation graph instead of building it.
/// </summary>
bool ImportGraph { get; set; }
bool IsImportingGraph { get; set; }

string Name { get; }

bool Run();

/// <summary>
/// Build dataflow graph, train and predict
/// </summary>
/// <returns></returns>
bool Run();
bool Train();

bool Predict();

Graph ImportGraph();

Graph BuildGraph();

/// <summary>
/// Prepare dataset
/// </summary>


+ 21
- 3
test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs View File

@@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples.ImageProcess
/// </summary>
public class ImageBackgroundRemoval : IExample
{
public int Priority => 15;

public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;

public string Name => "Image Background Removal";

@@ -59,5 +57,25 @@ namespace TensorFlowNET.Examples.ImageProcess
Web.Download(url, modelDir, fileName);
Compress.ExtractTGZ(Path.Join(modelDir, fileName), modelDir);*/
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs View File

@@ -20,10 +20,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class ImageRecognitionInception : IExample
{
public int Priority => 7;
public bool Enabled { get; set; } = true;
public string Name => "Image Recognition Inception";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


string dir = "ImageRecognitionInception";
@@ -115,5 +114,25 @@ namespace TensorFlowNET.Examples
file_ndarrays.Add(nd);
}
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs View File

@@ -21,9 +21,8 @@ namespace TensorFlowNET.Examples
public class InceptionArchGoogLeNet : IExample
{
public bool Enabled { get; set; } = false;
public int Priority => 100;
public string Name => "Inception Arch GoogLeNet";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


string dir = "label_image_data";
@@ -108,5 +107,25 @@ namespace TensorFlowNET.Examples
url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}";
Utility.Web.Download(url, dir, pic);
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs View File

@@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples

public class ObjectDetection : IExample
{
public int Priority => 11;
public bool Enabled { get; set; } = true;
public string Name => "Object Detection";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;

public float MIN_SCORE = 0.5f;

@@ -145,5 +144,25 @@ namespace TensorFlowNET.Examples
}
}
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 1
test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs View File

@@ -26,7 +26,7 @@ namespace TensorFlowNET.Examples.ImageProcess
public int Priority => 16;

public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;

public string Name => "Retrain Image Classifier";

@@ -667,5 +667,25 @@ namespace TensorFlowNET.Examples.ImageProcess

return result;
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 4
- 5
test/TensorFlowNET.Examples/Program.cs View File

@@ -19,7 +19,6 @@ namespace TensorFlowNET.Examples
var examples = Assembly.GetEntryAssembly().GetTypes()
.Where(x => x.GetInterfaces().Contains(typeof(IExample)))
.Select(x => (IExample)Activator.CreateInstance(x))
.OrderBy(x => x.Priority)
.ToArray();

Console.WriteLine($"TensorFlow v{tf.VERSION}", Color.Yellow);
@@ -42,18 +41,18 @@ namespace TensorFlowNET.Examples
sw.Stop();

if (isSuccess)
success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
success.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s");
else
errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
errors.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s");
}
else
{
disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms");
disabled.Add($"Example: {example.Name} in {sw.ElapsedMilliseconds}ms");
}
}
catch (Exception ex)
{
errors.Add($"Example {example.Priority}: {example.Name}");
errors.Add($"Example: {example.Name}");
Console.WriteLine(ex);
}


+ 21
- 2
test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs View File

@@ -17,10 +17,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class BinaryTextClassification : IExample
{
public int Priority => 9;
public bool Enabled { get; set; } = true;
public string Name => "Binary Text Classification";
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;

string dir = "binary_text_classification";
string dataFile = "imdb.zip";
@@ -138,5 +137,25 @@ namespace TensorFlowNET.Examples

return result;
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 23
- 4
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples
/// </summary>
public class CnnTextClassification : IExample
{
public int Priority => 17;
public bool Enabled { get; set; } = true;
public string Name => "CNN Text Classification";
public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = false;

private string dataDir = "word_cnn";
private string dataFileName = "dbpedia_csv.tar.gz";
@@ -49,7 +48,7 @@ namespace TensorFlowNET.Examples
var graph = tf.Graph().as_default();
return with(tf.Session(graph), sess =>
{
if (ImportGraph)
if (IsImportingGraph)
return RunWithImportedGraph(sess, graph);
else
return RunWithBuiltGraph(sess, graph);
@@ -222,7 +221,7 @@ namespace TensorFlowNET.Examples
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));

if (ImportGraph)
if (IsImportingGraph)
{
// download graph meta data
var meta_file = "word_cnn.meta";
@@ -237,5 +236,25 @@ namespace TensorFlowNET.Examples
Web.Download(url, "graph", meta_file);
}
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 3
test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs View File

@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples
/// </summary>
public class BiLstmCrfNer : IExample
{
public int Priority => 101;

public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;

public string Name => "bi-LSTM + CRF NER";

@@ -35,5 +33,25 @@ namespace TensorFlowNET.Examples
hp.filepath_tags = Path.Combine(hp.data_root_dir, "vocab.tags.txt");
hp.filepath_glove = Path.Combine(hp.data_root_dir, "glove.npz");
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 3
test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs View File

@@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples
/// </summary>
public class CRF : IExample
{
public int Priority => 13;

public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;

public string Name => "CRF";

@@ -31,5 +29,25 @@ namespace TensorFlowNET.Examples
{

}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 3
test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs View File

@@ -20,10 +20,8 @@ namespace TensorFlowNET.Examples.Text.NER
/// </summary>
public class LstmCrfNer : IExample
{
public int Priority => 14;

public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;

public string Name => "LSTM + CRF NER";

@@ -208,5 +206,25 @@ namespace TensorFlowNET.Examples.Text.NER
Web.Download(url, "graph", meta_file);

}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 22
- 3
test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs View File

@@ -11,13 +11,12 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NamedEntityRecognition : IExample
{
public int Priority => 100;
public bool Enabled { get; set; } = false;
public string Name => "NER";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public bool Run()
public bool Train()
{
throw new NotImplementedException();
}
@@ -26,5 +25,25 @@ namespace TensorFlowNET.Examples
{
throw new NotImplementedException();
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Run()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 24
- 5
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples
/// </summary>
public class TextClassificationTrain : IExample
{
public int Priority => 100;
public bool Enabled { get; set; } = false;
public string Name => "Text Classification";
public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia

private string dataDir = "text_classification";
@@ -51,7 +50,7 @@ namespace TensorFlowNET.Examples
var graph = tf.Graph().as_default();
return with(tf.Session(graph), sess =>
{
if (ImportGraph)
if (IsImportingGraph)
return RunWithImportedGraph(sess, graph);
else
return RunWithBuiltGraph(sess, graph);
@@ -255,7 +254,7 @@ namespace TensorFlowNET.Examples
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
}

if (ImportGraph)
if (IsImportingGraph)
{
// download graph meta data
var meta_file = model_name + ".meta";
@@ -269,6 +268,26 @@ namespace TensorFlowNET.Examples
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);
}
}
}
public Graph ImportGraph()
{
throw new NotImplementedException();
}
public Graph BuildGraph()
{
throw new NotImplementedException();
}
public bool Train()
{
throw new NotImplementedException();
}
public bool Predict()
{
throw new NotImplementedException();
}
}
}

+ 21
- 2
test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs View File

@@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples
/// </summary>
public class Word2Vec : IExample
{
public int Priority => 12;
public bool Enabled { get; set; } = true;
public string Name => "Word2Vec";
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;

// Training Parameters
float learning_rate = 0.1f;
@@ -205,6 +204,26 @@ namespace TensorFlowNET.Examples
print($"Most common words: {string.Join(", ", word2id.Take(10))}");
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}

private class WordId
{
public string Word { get; set; }


+ 15
- 15
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests
public void BasicOperations()
{
tf.Graph().as_default();
new BasicOperations() { Enabled = true }.Run();
new BasicOperations() { Enabled = true }.Train();
}
[TestMethod]
public void HelloWorld()
{
tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Run();
new HelloWorld() { Enabled = true }.Train();
}
[TestMethod]
public void ImageRecognition()
{
tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Run();
new HelloWorld() { Enabled = true }.Train();
}
[Ignore]
@@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests
public void InceptionArchGoogLeNet()
{
tf.Graph().as_default();
new InceptionArchGoogLeNet() { Enabled = true }.Run();
new InceptionArchGoogLeNet() { Enabled = true }.Train();
}
[TestMethod]
public void KMeansClustering()
{
tf.Graph().as_default();
new KMeansClustering() { Enabled = true, ImportGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train();
}
[TestMethod]
public void LinearRegression()
{
tf.Graph().as_default();
new LinearRegression() { Enabled = true }.Run();
new LinearRegression() { Enabled = true }.Train();
}
[TestMethod]
public void LogisticRegression()
{
tf.Graph().as_default();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train();
}
[Ignore]
@@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests
public void NaiveBayesClassifier()
{
tf.Graph().as_default();
new NaiveBayesClassifier() { Enabled = false }.Run();
new NaiveBayesClassifier() { Enabled = false }.Train();
}
[Ignore]
@@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests
public void NamedEntityRecognition()
{
tf.Graph().as_default();
new NamedEntityRecognition() { Enabled = true }.Run();
new NamedEntityRecognition() { Enabled = true }.Train();
}
[TestMethod]
public void NearestNeighbor()
{
tf.Graph().as_default();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train();
}
[Ignore]
@@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationTrain()
{
tf.Graph().as_default();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train();
}
[Ignore]
@@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationWithMovieReviews()
{
tf.Graph().as_default();
new BinaryTextClassification() { Enabled = true }.Run();
new BinaryTextClassification() { Enabled = true }.Train();
}
[TestMethod]
public void NeuralNetXor()
{
tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = false }.Run());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train());
}
[TestMethod]
public void NeuralNetXor_ImportedGraph()
{
tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = true }.Run());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train());
}
@@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests
public void ObjectDetection()
{
tf.Graph().as_default();
Assert.IsTrue(new ObjectDetection() { Enabled = true, ImportGraph = true }.Run());
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train());
}
}
}

Loading…
Cancel
Save