Browse Source

fix RefVariable substraction.

tags/v0.12
Oceania2018 6 years ago
parent
commit
be024158f9
1 changed files with 20 additions and 6 deletions
  1. +20
    -6
      src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs

+ 20
- 6
src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -29,16 +30,29 @@ namespace Tensorflow
public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y);
public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y);

public static Tensor operator <(RefVariable x, Tensor y) => op_helper("Less", x, y);
public static Tensor operator <(RefVariable x, Tensor y) => gen_math_ops.less(x.value(), y);

public static Tensor operator >(RefVariable x, Tensor y) => op_helper("Greater", x, y);
public static Tensor operator >(RefVariable x, Tensor y) => gen_math_ops.greater(x.value(), y);

private static Tensor op_helper<T>(string default_name, RefVariable x, T y)
{
var tensor1 = x.value();
return tf_with(ops.name_scope(null, default_name, new { tensor1, y }), scope => {
var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y");
return gen_math_ops.add(tensor1, tensor2, scope);
var xVal = x.value();
return tf_with(ops.name_scope(null, default_name, new { xVal, y }), scope => {
string name = scope;
var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y");
Tensor result = null;
switch (default_name)
{
case "add":
result = gen_math_ops.add(xVal, yTensor, name);
break;
case "sub":
result = gen_math_ops.sub(xVal, yTensor, name);
break;
default:
throw new NotImplementedException("");
}
return result;
});
}
}


Loading…
Cancel
Save