diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 046732c1..d449eaf8 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -143,9 +143,6 @@ namespace Tensorflow } }); - - // temp fix name scope - op.Graph._name_stack = "gradients"; } } else diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs index f2b146a4..2560a2e5 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs @@ -48,7 +48,8 @@ namespace Tensorflow var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx); var realdiv2 = gen_math_ops.real_div(-x, y); var realdiv3 = gen_math_ops.real_div(realdiv2, y); - var reduce_sum2 = math_ops.reduce_sum(grad * realdiv3, ry); + var mul = grad * realdiv3; + var reduce_sum2 = math_ops.reduce_sum(mul, ry); return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy)); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 3d46196d..dc25f853 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -23,7 +23,6 @@ namespace Tensorflow private List _unfetchable_ops = new List(); public string _name_stack = ""; - public string old_stack = ""; public string _graph_key; public Status Status { get; } @@ -180,8 +179,6 @@ namespace Tensorflow public string name_scope(string name) { - old_stack = _name_stack; - string new_stack = ""; if (name.EndsWith("/")) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index df139aa3..56c6f715 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -94,10 +94,6 @@ namespace Tensorflow return constant_op.constant(nd, name); } } - else - { - // result = gen_array_ops.shape(); - } return gen_array_ops.shape(input); }); diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index c0e60b52..0fee8330 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -14,6 +14,7 @@ namespace Tensorflow public object _values; public Context _ctx; public string _name_scope; + public string old_stack = ""; private object _g_manager; public name_scope(string name, string default_name = "", object values = null) @@ -38,15 +39,14 @@ namespace Tensorflow if (g == null) g = get_default_graph(); + old_stack = g._name_stack; _name_scope = g.name_scope(_name); } public void Dispose() { var g = get_default_graph(); - g._name_stack = g.old_stack; - // clear g._name_stack - g.old_stack = ""; + g._name_stack = old_stack; } public void __exit__() diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 823fc6fd..ff3b1c5b 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -294,11 +294,11 @@ namespace Tensorflow switch (oper.type) { case "Add": - return math_grad._AddGrad(op, out_grads); + return math_grad._AddGrad(oper, out_grads); case "Sum": - return math_grad._SumGrad(op, out_grads); + return math_grad._SumGrad(oper, out_grads); case "RealDiv": - return math_grad._RealDivGrad(op, out_grads); + return math_grad._RealDivGrad(oper, out_grads); default: throw new NotImplementedException($"get_gradient_function {oper.type}"); } diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs new file mode 100644 index 00000000..64b2bd1f --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -0,0 +1,45 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class NameScopeTest : Python + { + Graph g = ops.get_default_graph(); + string name = ""; + + [TestMethod] + public void NestedNameScope() + { + with(new ops.name_scope("scope1"), scope1 => + { + name = scope1; + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + with(new ops.name_scope("scope2"), scope2 => + { + name = scope2; + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + }); + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + }); + + Assert.AreEqual("", g._name_stack); + } + } +}