diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 4c08de1d..2287e1eb 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -146,7 +146,7 @@ namespace Tensorflow var inputs = _NonEagerInputs(op, xs).ToList(); foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads)) { - if(in_grad != null) + if(in_grad.op != null) { in_grad.shape = t_in.shape; _SetGrad(grads, t_in, in_grad); diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index 0ce36b00..2ddd119c 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -53,6 +53,25 @@ namespace Tensorflow } } + public static TOut with(IPython py, Func action) where TIn : IPython + { + try + { + py.__enter__(); + return action((TIn)py); + } + catch (Exception ex) + { + Console.WriteLine(ex.ToString()); + throw ex; + } + finally + { + py.__exit__(); + py.Dispose(); + } + } + public static IEnumerable<(T, T)> zip(NDArray t1, NDArray t2) { int index = 0; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index e0e7e75e..9923bb48 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -23,6 +23,16 @@ namespace Tensorflow return gen_math_ops.sub(t1, t2); } + public static Tensor operator *(double x, Tensor y) + { + return Python.with(new ops.name_scope("", "mul", new { x, y }), + scope => + { + var x1 = ops.convert_to_tensor(x, y.dtype.as_base_dtype(), name: "x"); + return gen_math_ops.mul(x1, y, name: scope); + }); + } + public static Tensor operator *(Tensor x, Tensor y) { Tensor t = null; diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index 317c8e01..6b937425 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -14,8 +14,19 @@ namespace TensorFlowNET.UnitTest { var a = tf.constant(0.0); var b = 2.0 * a; - var c = a + b; - var g = tf.gradients(c, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b }); + Assert.AreEqual(b.name, "mul:0"); + Assert.AreEqual(b.op.inputs[0].name, "mul/x:0"); + Assert.AreEqual(b.op.inputs[1].name, "Const:0"); + + var ys = a + b; + Assert.AreEqual(ys.name, "add:0"); + Assert.AreEqual(ys.op.inputs[0].name, "Const:0"); + Assert.AreEqual(ys.op.inputs[1].name, "mul:0"); + + var xs = new Tensor[] { a, b }; + var g = tf.gradients(ys, xs, stop_gradients: new Tensor[] { a, b }); + Assert.AreEqual(g[0].name, "gradients/Fill:0"); + Assert.AreEqual(g[1].name, "gradients/Fill:0"); } } }