Browse Source

change tf.math.multiply, math.add to generic.

tags/v0.9
Oceania2018 6 years ago
parent
commit
5627443b7e
4 changed files with 46 additions and 16 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +41
    -11
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow
public static Tensor asin(Tensor x, string name = null)
=> gen_math_ops.asin(x, name);

public static Tensor add(Tensor a, Tensor b)
public static Tensor add<Tx, Ty>(Tx a, Ty b)
=> gen_math_ops.add(a, b);

/// <summary>
@@ -251,7 +251,7 @@ namespace Tensorflow
public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null)
=> gen_math_ops.minimum(x, y, name: name);

public static Tensor multiply(Tensor x, Tensor y)
public static Tensor multiply<Tx, Ty>(Tx x, Ty y)
=> gen_math_ops.mul(x, y);

public static Tensor negative(Tensor x, string name = null)


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -287,7 +287,7 @@ namespace Tensorflow
// Reset cached inputs.
_inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None
// TODO: implement below code dependencies
//c_api.UpdateEdge(_graph._c_graph, output, input);
// c_api.TF_UpdateEdge(graph, output, input, status);
}

private void _assert_same_graph(Tensor tensor)


+ 2
- 2
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -80,7 +80,7 @@ namespace Tensorflow
return _op.outputs[0];
}
public static Tensor add(Tensor x, Tensor y, string name = null)
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y });
@@ -300,7 +300,7 @@ namespace Tensorflow
return _op.outputs[0];
}
public static Tensor mul(Tensor x, Tensor y, string name = null)
public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });


+ 41
- 11
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
@@ -18,25 +19,54 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var x = tf.constant(2);
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
() => tf.multiply(x, 17),
() => tf.add(y, 23));
int result = z.eval(sess);
assertEquals(result, 34);
});
}
[Ignore("Todo")]
[TestMethod]
public void testCondFalse()
{
// def testCondFalse(self):
// x = constant_op.constant(2)
// y = constant_op.constant(1)
// z = control_flow_ops.cond(
// math_ops.less(
// x,
// y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 24)
/* python
* import tensorflow as tf
from tensorflow.python.framework import ops
def if_true():
return tf.math.multiply(x, 17)
def if_false():
return tf.math.add(y, 23)
with tf.Session() as sess:
x = tf.constant(2)
y = tf.constant(1)
pred = tf.math.less(x,y)
z = tf.cond(pred, if_true, if_false)
result = z.eval()
print(result == 24) */
with(tf.Session(), sess =>
{
var x = tf.constant(2);
var y = tf.constant(1);
var pred = tf.less(x, y);
Func<ITensorOrOperation> if_true = delegate
{
return tf.multiply(x, 17);
};
Func<ITensorOrOperation> if_false = delegate
{
return tf.add(y, 23);
};
var z = control_flow_ops.cond(pred, if_true, if_false);
int result = z.eval(sess);
assertEquals(result, 24);
});
}
[Ignore("Todo")]


Loading…
Cancel
Save