@@ -41,7 +41,11 @@ namespace Tensorflow | |||||
public void close_variable_subscopes(string scope_name) | public void close_variable_subscopes(string scope_name) | ||||
{ | { | ||||
var variable_scopes_count_tmp = new Dictionary<string, int>(); | |||||
foreach (var k in variable_scopes_count.Keys) | foreach (var k in variable_scopes_count.Keys) | ||||
variable_scopes_count_tmp.Add(k, variable_scopes_count[k]); | |||||
foreach (var k in variable_scopes_count_tmp.Keys) | |||||
if (scope_name == null || k.StartsWith(scope_name + "/")) | if (scope_name == null || k.StartsWith(scope_name + "/")) | ||||
variable_scopes_count[k] = 0; | variable_scopes_count[k] = 0; | ||||
} | } | ||||
@@ -73,6 +73,8 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
float accuracy_test = 0f; | float accuracy_test = 0f; | ||||
float loss_test = 1f; | float loss_test = 1f; | ||||
NDArray x_train; | |||||
public bool Run() | public bool Run() | ||||
{ | { | ||||
PrepareData(); | PrepareData(); | ||||
@@ -241,11 +243,19 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
public void PrepareData() | public void PrepareData() | ||||
{ | { | ||||
mnist = MNIST.read_data_sets("mnist", one_hot: true); | mnist = MNIST.read_data_sets("mnist", one_hot: true); | ||||
x_train = Reformat(mnist.train.data, mnist.train.labels); | |||||
print("Size of:"); | print("Size of:"); | ||||
print($"- Training-set:\t\t{len(mnist.train.data)}"); | print($"- Training-set:\t\t{len(mnist.train.data)}"); | ||||
print($"- Validation-set:\t{len(mnist.validation.data)}"); | print($"- Validation-set:\t{len(mnist.validation.data)}"); | ||||
} | } | ||||
private NDArray Reformat(NDArray x, NDArray y) | |||||
{ | |||||
var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, np.unique<int>(np.argmax(y, 1))); | |||||
return x; | |||||
} | |||||
public void Train(Session sess) | public void Train(Session sess) | ||||
{ | { | ||||
// Number of training iterations in each epoch | // Number of training iterations in each epoch | ||||
@@ -39,6 +39,7 @@ namespace TensorFlowNET.ExamplesTests | |||||
new InceptionArchGoogLeNet() { Enabled = true }.Run(); | new InceptionArchGoogLeNet() { Enabled = true }.Run(); | ||||
} | } | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void KMeansClustering() | public void KMeansClustering() | ||||
{ | { | ||||
@@ -83,10 +84,12 @@ namespace TensorFlowNET.ExamplesTests | |||||
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); | new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); | ||||
} | } | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void WordCnnTextClassification() | public void WordCnnTextClassification() | ||||
=> new CnnTextClassification { Enabled = true, ModelName = "word_cnn", DataLimit =100 }.Run(); | => new CnnTextClassification { Enabled = true, ModelName = "word_cnn", DataLimit =100 }.Run(); | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void CharCnnTextClassification() | public void CharCnnTextClassification() | ||||
=> new CnnTextClassification { Enabled = true, ModelName = "char_cnn", DataLimit = 100 }.Run(); | => new CnnTextClassification { Enabled = true, ModelName = "char_cnn", DataLimit = 100 }.Run(); | ||||