From be024158f9c991269929f2f0b97f797a54fc962d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 2 Sep 2019 18:25:34 -0500 Subject: [PATCH] fix RefVariable substraction. --- .../Variables/RefVariable.Operators.cs | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs index 4ee9db76..79d7dd5f 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using static Tensorflow.Binding; namespace Tensorflow @@ -29,16 +30,29 @@ namespace Tensorflow public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y); - public static Tensor operator <(RefVariable x, Tensor y) => op_helper("Less", x, y); + public static Tensor operator <(RefVariable x, Tensor y) => gen_math_ops.less(x.value(), y); - public static Tensor operator >(RefVariable x, Tensor y) => op_helper("Greater", x, y); + public static Tensor operator >(RefVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); private static Tensor op_helper(string default_name, RefVariable x, T y) { - var tensor1 = x.value(); - return tf_with(ops.name_scope(null, default_name, new { tensor1, y }), scope => { - var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y"); - return gen_math_ops.add(tensor1, tensor2, scope); + var xVal = x.value(); + return tf_with(ops.name_scope(null, default_name, new { xVal, y }), scope => { + string name = scope; + var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y"); + Tensor result = null; + switch (default_name) + { + case "add": + result = gen_math_ops.add(xVal, yTensor, name); + break; + case "sub": + result = gen_math_ops.sub(xVal, yTensor, name); + break; + default: + throw new NotImplementedException(""); + } + return result; }); } }