Browse Source

fixed #164

tags/v0.8.0
haiping008 6 years ago
parent
commit
4a80846d0f
7 changed files with 53 additions and 17 deletions
  1. +0
    -3
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Gradients/math_grad.py.cs
  3. +0
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +0
    -4
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +3
    -3
      src/TensorFlowNET.Core/ops.name_scope.cs
  6. +3
    -3
      src/TensorFlowNET.Core/ops.py.cs
  7. +45
    -0
      test/TensorFlowNET.UnitTest/NameScopeTest.cs

+ 0
- 3
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -143,9 +143,6 @@ namespace Tensorflow

}
});

// temp fix name scope
op.Graph._name_stack = "gradients";
}
}
else


+ 2
- 1
src/TensorFlowNET.Core/Gradients/math_grad.py.cs View File

@@ -48,7 +48,8 @@ namespace Tensorflow
var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx);
var realdiv2 = gen_math_ops.real_div(-x, y);
var realdiv3 = gen_math_ops.real_div(realdiv2, y);
var reduce_sum2 = math_ops.reduce_sum(grad * realdiv3, ry);
var mul = grad * realdiv3;
var reduce_sum2 = math_ops.reduce_sum(mul, ry);

return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy));
}


+ 0
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -23,7 +23,6 @@ namespace Tensorflow
private List<String> _unfetchable_ops = new List<string>();

public string _name_stack = "";
public string old_stack = "";
public string _graph_key;
public Status Status { get; }

@@ -180,8 +179,6 @@ namespace Tensorflow

public string name_scope(string name)
{
old_stack = _name_stack;

string new_stack = "";

if (name.EndsWith("/"))


+ 0
- 4
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -94,10 +94,6 @@ namespace Tensorflow
return constant_op.constant(nd, name);
}
}
else
{
// result = gen_array_ops.shape();
}

return gen_array_ops.shape(input);
});


+ 3
- 3
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -14,6 +14,7 @@ namespace Tensorflow
public object _values;
public Context _ctx;
public string _name_scope;
public string old_stack = "";
private object _g_manager;

public name_scope(string name, string default_name = "", object values = null)
@@ -38,15 +39,14 @@ namespace Tensorflow
if (g == null)
g = get_default_graph();

old_stack = g._name_stack;
_name_scope = g.name_scope(_name);
}

public void Dispose()
{
var g = get_default_graph();
g._name_stack = g.old_stack;
// clear g._name_stack
g.old_stack = "";
g._name_stack = old_stack;
}

public void __exit__()


+ 3
- 3
src/TensorFlowNET.Core/ops.py.cs View File

@@ -294,11 +294,11 @@ namespace Tensorflow
switch (oper.type)
{
case "Add":
return math_grad._AddGrad(op, out_grads);
return math_grad._AddGrad(oper, out_grads);
case "Sum":
return math_grad._SumGrad(op, out_grads);
return math_grad._SumGrad(oper, out_grads);
case "RealDiv":
return math_grad._RealDivGrad(op, out_grads);
return math_grad._RealDivGrad(oper, out_grads);
default:
throw new NotImplementedException($"get_gradient_function {oper.type}");
}


+ 45
- 0
test/TensorFlowNET.UnitTest/NameScopeTest.cs View File

@@ -0,0 +1,45 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class NameScopeTest : Python
{
Graph g = ops.get_default_graph();
string name = "";

[TestMethod]
public void NestedNameScope()
{
with<ops.name_scope>(new ops.name_scope("scope1"), scope1 =>
{
name = scope1;
Assert.AreEqual("scope1", g._name_stack);
Assert.AreEqual("scope1/", name);

var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);

with<ops.name_scope>(new ops.name_scope("scope2"), scope2 =>
{
name = scope2;
Assert.AreEqual("scope1/scope2", g._name_stack);
Assert.AreEqual("scope1/scope2/", name);

var const2 = tf.constant(2.0);
Assert.AreEqual("scope1/scope2/Const:0", const2.name);
});

Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
});

Assert.AreEqual("", g._name_stack);
}
}
}

Loading…
Cancel
Save