From 5b690a05f30718de54923a96afac825e346c5dd7 Mon Sep 17 00:00:00 2001 From: Bogdan Pietroiu Date: Mon, 17 Jun 2019 17:03:19 +0300 Subject: [PATCH] CnnTextClassification debugging --- src/TensorFlowNET.Core/Train/AdamOptimizer.cs | 7 +++++++ .../TextProcess/CnnTextClassification.cs | 11 ++++++----- .../ExamplesTests/ExamplesTest.cs | 9 +++++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index 56e69881..8b14bf50 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -70,5 +70,12 @@ namespace Tensorflow.Train return (_get_non_slot_variable("beta1_power", graph: graph), _get_non_slot_variable("beta2_power", graph: graph)); } + + public override void _prepare() + { + //copied from GradientDescentOptimizer + LearningRate = _call_if_callable(LearningRate); + LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate"); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index 465b08b2..4caf3b58 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -57,10 +57,10 @@ namespace TensorFlowNET.Examples //int classes = y.Data().Distinct().Count(); //int samples = len / classes; int train_size = (int)Math.Round(len * (1 - test_size)); - var train_x = x[new Slice(stop: train_size), new Slice()]; - var valid_x = x[new Slice(start: train_size), new Slice()]; - var train_y = y[new Slice(stop: train_size)]; - var valid_y = y[new Slice(start: train_size)]; + train_x = x[new Slice(stop: train_size), new Slice()]; + valid_x = x[new Slice(start: train_size), new Slice()]; + train_y = y[new Slice(stop: train_size)]; + valid_y = y[new Slice(start: train_size)]; Console.WriteLine("\tDONE"); return (train_x, valid_x, train_y, valid_y); } @@ -135,7 +135,8 @@ namespace TensorFlowNET.Examples { // delete old cached file which contains errors Console.WriteLine("Discarding cached file: " + meta_path); - File.Delete(meta_path); + if(File.Exists(meta_path)) + File.Delete(meta_path); } var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; Web.Download(url, "graph", meta_file); diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index 214bb835..a267324e 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -91,6 +91,15 @@ namespace TensorFlowNET.ExamplesTests new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); } + + [TestMethod] + public void CnnTextClassificationTrain() + { + tf.Graph().as_default(); + new CnnTextClassification() { Enabled = true, IsImportingGraph = false }.Run(); + } + + [Ignore] [TestMethod] public void TextClassificationWithMovieReviews()