Browse Source

fixed operand should be mul/x:0 #155

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
f3427cb361
4 changed files with 43 additions and 3 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  2. +19
    -0
      src/TensorFlowNET.Core/Python.cs
  3. +10
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  4. +13
    -2
      test/TensorFlowNET.UnitTest/GradientTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -146,7 +146,7 @@ namespace Tensorflow
var inputs = _NonEagerInputs(op, xs).ToList(); var inputs = _NonEagerInputs(op, xs).ToList();
foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads)) foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads))
{ {
if(in_grad != null)
if(in_grad.op != null)
{ {
in_grad.shape = t_in.shape; in_grad.shape = t_in.shape;
_SetGrad(grads, t_in, in_grad); _SetGrad(grads, t_in, in_grad);


+ 19
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -53,6 +53,25 @@ namespace Tensorflow
} }
} }


public static TOut with<TIn, TOut>(IPython py, Func<TIn, TOut> action) where TIn : IPython
{
try
{
py.__enter__();
return action((TIn)py);
}
catch (Exception ex)
{
Console.WriteLine(ex.ToString());
throw ex;
}
finally
{
py.__exit__();
py.Dispose();
}
}

public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2)
{ {
int index = 0; int index = 0;


+ 10
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -23,6 +23,16 @@ namespace Tensorflow
return gen_math_ops.sub(t1, t2); return gen_math_ops.sub(t1, t2);
} }


public static Tensor operator *(double x, Tensor y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new { x, y }),
scope =>
{
var x1 = ops.convert_to_tensor(x, y.dtype.as_base_dtype(), name: "x");
return gen_math_ops.mul(x1, y, name: scope);
});
}

public static Tensor operator *(Tensor x, Tensor y) public static Tensor operator *(Tensor x, Tensor y)
{ {
Tensor t = null; Tensor t = null;


+ 13
- 2
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -14,8 +14,19 @@ namespace TensorFlowNET.UnitTest
{ {
var a = tf.constant(0.0); var a = tf.constant(0.0);
var b = 2.0 * a; var b = 2.0 * a;
var c = a + b;
var g = tf.gradients(c, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });
Assert.AreEqual(b.name, "mul:0");
Assert.AreEqual(b.op.inputs[0].name, "mul/x:0");
Assert.AreEqual(b.op.inputs[1].name, "Const:0");

var ys = a + b;
Assert.AreEqual(ys.name, "add:0");
Assert.AreEqual(ys.op.inputs[0].name, "Const:0");
Assert.AreEqual(ys.op.inputs[1].name, "mul:0");

var xs = new Tensor[] { a, b };
var g = tf.gradients(ys, xs, stop_gradients: new Tensor[] { a, b });
Assert.AreEqual(g[0].name, "gradients/Fill:0");
Assert.AreEqual(g[1].name, "gradients/Fill:0");
} }
} }
} }

Loading…
Cancel
Save