Browse Source

add ResourceVarible operate functions

tags/v0.20
pepure 5 years ago
parent
commit
afee297e9c
2 changed files with 45 additions and 16 deletions
  1. +43
    -14
      src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs
  2. +2
    -2
      test/TensorFlowNET.UnitTest/Basics/VariableTest.cs

+ 43
- 14
src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs View File

@@ -22,23 +22,26 @@ namespace Tensorflow
{
public partial class ResourceVariable
{
public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y);
public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y);
public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y);
public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y);
public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y);
public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y);
public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y);
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y);
public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y);
public static ResourceVariable operator +(ResourceVariable x, int y) => op_helper("add", x, y);
public static ResourceVariable operator +(ResourceVariable x, float y) => op_helper("add", x, y);
public static ResourceVariable operator +(ResourceVariable x, double y) => op_helper("add", x, y);
public static ResourceVariable operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y);
public static ResourceVariable operator -(ResourceVariable x, int y) => op_helper("sub", x, y);
public static ResourceVariable operator -(ResourceVariable x, float y) => op_helper("sub", x, y);
public static ResourceVariable operator -(ResourceVariable x, double y) => op_helper("sub", x, y);
public static ResourceVariable operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y);
public static ResourceVariable operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y);

public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y);
public static ResourceVariable operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y);
public static ResourceVariable operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y);

public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y);
public static ResourceVariable operator <(ResourceVariable x, Tensor y) => less(x.value(), y);

private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y)
public static ResourceVariable operator >(ResourceVariable x, Tensor y) => greater(x.value(), y);

private static ResourceVariable op_helper<T>(string default_name, ResourceVariable x, T y)
=> tf_with(ops.name_scope(null, default_name, new { x, y }), scope =>
{
string name = scope;
@@ -64,7 +67,33 @@ namespace Tensorflow

// x.assign(result);
// result.ResourceVar = x;
return result;
return tf.Variable(result);
});

private static ResourceVariable less<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.context.executing_eagerly())
{
var results = EagerTensorPass.Create();
var inputs = EagerTensorPass.From(x, y);
Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Less", name,
inputs.Points, inputs.Length,
null, null,
results.Points, results.Length);
status.Check(true);
return tf.Variable(results[0].Resolve());
}

var _op = _op_def_lib._apply_op_helper("Less", name: name, args: new { x, y });

return tf.Variable(_op.outputs[0]);
}
private static ResourceVariable greater<Tx, Ty>(Tx x, Ty y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y });

return tf.Variable(_op.outputs[0]);
}
}
}

+ 2
- 2
test/TensorFlowNET.UnitTest/Basics/VariableTest.cs View File

@@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest.Basics
public void Accumulation()
{
var x = tf.Variable(10, name: "x");
/*for (int i = 0; i < 5; i++)
for (int i = 0; i < 5; i++)
x = x + 1;

Assert.AreEqual(15, (int)x.numpy());*/
Assert.AreEqual(15, (int)x.numpy());
}

[TestMethod]


Loading…
Cancel
Save