diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs index b96576e5..a874832f 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs @@ -22,23 +22,26 @@ namespace Tensorflow { public partial class ResourceVariable { - public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); - public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); - public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); - - public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); - public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); - public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); + public static ResourceVariable operator +(ResourceVariable x, int y) => op_helper("add", x, y); + public static ResourceVariable operator +(ResourceVariable x, float y) => op_helper("add", x, y); + public static ResourceVariable operator +(ResourceVariable x, double y) => op_helper("add", x, y); + public static ResourceVariable operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y); + public static ResourceVariable operator -(ResourceVariable x, int y) => op_helper("sub", x, y); + public static ResourceVariable operator -(ResourceVariable x, float y) => op_helper("sub", x, y); + public static ResourceVariable operator -(ResourceVariable x, double y) => op_helper("sub", x, y); + public static ResourceVariable operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); + public static ResourceVariable operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y); - public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); + public static ResourceVariable operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); + public static ResourceVariable operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); - public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); + public static ResourceVariable operator <(ResourceVariable x, Tensor y) => less(x.value(), y); - private static Tensor op_helper(string default_name, ResourceVariable x, T y) + public static ResourceVariable operator >(ResourceVariable x, Tensor y) => greater(x.value(), y); + + private static ResourceVariable op_helper(string default_name, ResourceVariable x, T y) => tf_with(ops.name_scope(null, default_name, new { x, y }), scope => { string name = scope; @@ -64,7 +67,33 @@ namespace Tensorflow // x.assign(result); // result.ResourceVar = x; - return result; + return tf.Variable(result); }); + + private static ResourceVariable less(Tx x, Ty y, string name = null) + { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); + Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Less", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length); + status.Check(true); + return tf.Variable(results[0].Resolve()); + } + + var _op = _op_def_lib._apply_op_helper("Less", name: name, args: new { x, y }); + + return tf.Variable(_op.outputs[0]); + } + private static ResourceVariable greater(Tx x, Ty y, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y }); + + return tf.Variable(_op.outputs[0]); + } } } diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs index e41edb89..a6302655 100644 --- a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs @@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest.Basics public void Accumulation() { var x = tf.Variable(10, name: "x"); - /*for (int i = 0; i < 5; i++) + for (int i = 0; i < 5; i++) x = x + 1; - Assert.AreEqual(15, (int)x.numpy());*/ + Assert.AreEqual(15, (int)x.numpy()); } [TestMethod]