From 477d03db16bbbda93b8487dbb502347606415ab2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 23 Jun 2019 20:43:43 -0500 Subject: [PATCH] remove order of _control_dependencies_for_inputs. --- src/TensorFlowNET.Core/Graphs/Graph.Control.cs | 2 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 3 --- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 3 +-- .../TextProcess/CnnTextClassification.cs | 10 ++++------ 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index f88f02c9..fda9ff01 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -43,7 +43,7 @@ namespace Tensorflow ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x))); } - return ret.OrderBy(x => x.op.name).ToArray(); + return ret.ToArray(); } /// diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 657e2589..66bd6bbe 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -248,9 +248,6 @@ namespace Tensorflow var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); var input_ops = inputs.Select(x => x.op).ToArray(); - if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape") - ; - var control_inputs = _control_dependencies_for_inputs(input_ops); var op = new Operation(node_def, diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 59ed78f1..7bc7ae52 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io true - TRACE;DEBUG;GRAPH_SERIALIZE + TRACE;DEBUG @@ -54,7 +54,6 @@ Docs: https://tensorflownet.readthedocs.io - diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index cd370d1b..883783b2 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -22,7 +22,7 @@ namespace TensorFlowNET.Examples public bool Enabled { get; set; } = true; public string Name => "CNN Text Classification"; public int? DataLimit = null; - public bool IsImportingGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = true; private const string dataDir = "word_cnn"; private string dataFileName = "dbpedia_csv.tar.gz"; @@ -44,9 +44,7 @@ namespace TensorFlowNET.Examples { PrepareData(); - Train(); - - return true; + return Train(); } // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here @@ -305,13 +303,13 @@ namespace TensorFlowNET.Examples } } - return false; + return max_accuracy > 0.8; } public bool Train() { var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); - string json = JsonConvert.SerializeObject(graph, Formatting.Indented); + // string json = JsonConvert.SerializeObject(graph, Formatting.Indented); return with(tf.Session(graph), sess => Train(sess, graph)); }