Browse Source

Merge pull request #272 from bpietroiu/Issue-CnnTextClassification

CnnTextClassification debugging
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
60ec5af747
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 5 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/Train/AdamOptimizer.cs
  2. +6
    -5
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
  3. +9
    -0
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 7
- 0
src/TensorFlowNET.Core/Train/AdamOptimizer.cs View File

@@ -70,5 +70,12 @@ namespace Tensorflow.Train
return (_get_non_slot_variable("beta1_power", graph: graph),
_get_non_slot_variable("beta2_power", graph: graph));
}

public override void _prepare()
{
//copied from GradientDescentOptimizer
LearningRate = _call_if_callable(LearningRate);
LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
}
}
}

+ 6
- 5
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -57,10 +57,10 @@ namespace TensorFlowNET.Examples
//int classes = y.Data<int>().Distinct().Count();
//int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size), new Slice()];
var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size)];
train_x = x[new Slice(stop: train_size), new Slice()];
valid_x = x[new Slice(start: train_size), new Slice()];
train_y = y[new Slice(stop: train_size)];
valid_y = y[new Slice(start: train_size)];
Console.WriteLine("\tDONE");
return (train_x, valid_x, train_y, valid_y);
}
@@ -135,7 +135,8 @@ namespace TensorFlowNET.Examples
{
// delete old cached file which contains errors
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
if(File.Exists(meta_path))
File.Delete(meta_path);
}
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);


+ 9
- 0
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -91,6 +91,15 @@ namespace TensorFlowNET.ExamplesTests
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
}
[TestMethod]
public void CnnTextClassificationTrain()
{
tf.Graph().as_default();
new CnnTextClassification() { Enabled = true, IsImportingGraph = false }.Run();
}
[Ignore]
[TestMethod]
public void TextClassificationWithMovieReviews()


Loading…
Cancel
Save