using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json;
using System;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
/// <summary>
/// Port of the <c>CondTest</c> cases from
/// tensorflow/python/ops/control_flow_ops_test.py.
/// Only <see cref="testCondTrue"/> and <see cref="testCondFalse"/> are implemented;
/// the remaining cases are <c>[Ignore]</c>d placeholders that keep the upstream
/// Python source as comments for reference.
/// </summary>
[TestClass]
public class CondTestCases : PythonTest
{
    /// <summary>
    /// <c>control_flow_ops.cond</c> must take the true branch when the
    /// predicate <c>2 &lt; 5</c> holds.
    /// </summary>
    /// <remarks>
    /// The upstream test uses multiply/add branches (see the legacy variants
    /// below); this port uses constant branches, so the true branch yields 2.
    /// </remarks>
    [TestMethod]
    public void testCondTrue()
    {
        var graph = tf.Graph().as_default();
        with(tf.Session(graph), sess =>
        {
            var x = tf.constant(2, name: "x");
            var y = tf.constant(5, name: "y");
            var pred = tf.less(x, y);

            // Passing lambdas directly lets the compiler infer the delegate
            // type expected by cond's branch parameters.
            var z = control_flow_ops.cond(pred,
                () => tf.constant(2, name: "t2"),
                () => tf.constant(5, name: "f5"));

            int result = z.eval(sess);
            assertEquals(result, 2);
        });
    }

    /// <summary>
    /// <c>control_flow_ops.cond</c> must take the false branch when the
    /// predicate <c>2 &lt; 1</c> does not hold.
    /// </summary>
    [TestMethod]
    public void testCondFalse()
    {
        // Upstream Python reference (expects add(y, 23) == 24):
        //
        //   import tensorflow as tf
        //   def if_true():
        //       return tf.math.multiply(x, 17)
        //   def if_false():
        //       return tf.math.add(y, 23)
        //   with tf.Session() as sess:
        //       x = tf.constant(2)
        //       y = tf.constant(1)
        //       pred = tf.math.less(x, y)
        //       z = tf.cond(pred, if_true, if_false)
        //       result = z.eval()
        //       print(result == 24)
        //
        // This port uses constant branches instead, so the false branch yields 1.
        var graph = tf.Graph().as_default();
        with(tf.Session(graph), sess =>
        {
            var x = tf.constant(2, name: "x");
            var y = tf.constant(1, name: "y");
            var pred = tf.less(x, y);

            var z = control_flow_ops.cond(pred,
                () => tf.constant(2, name: "t2"),
                () => tf.constant(1, name: "f1"));

            int result = z.eval(sess);
            assertEquals(result, 1);
        });
    }

    [Ignore("Todo")]
    [TestMethod]
    public void testCondTrueLegacy()
    {
        // def testCondTrueLegacy(self):
        //   x = constant_op.constant(2)
        //   y = constant_op.constant(5)
        //   z = control_flow_ops.cond(
        //       math_ops.less(x, y),
        //       fn1=lambda: math_ops.multiply(x, 17),
        //       fn2=lambda: math_ops.add(y, 23))
        //   self.assertEquals(self.evaluate(z), 34)
    }

    [Ignore("Todo")]
    [TestMethod]
    public void testCondFalseLegacy()
    {
        // def testCondFalseLegacy(self):
        //   x = constant_op.constant(2)
        //   y = constant_op.constant(1)
        //   z = control_flow_ops.cond(
        //       math_ops.less(x, y),
        //       fn1=lambda: math_ops.multiply(x, 17),
        //       fn2=lambda: math_ops.add(y, 23))
        //   self.assertEquals(self.evaluate(z), 24)
    }

    [Ignore("Todo")]
    [TestMethod]
    public void testCondMissingArg1()
    {
        // def testCondMissingArg1(self):
        //   x = constant_op.constant(1)
        //   with self.assertRaises(TypeError):
        //     control_flow_ops.cond(True, false_fn=lambda: x)
    }

    [Ignore("Todo")]
    [TestMethod]
    public void testCondMissingArg2()
    {
        // def testCondMissingArg2(self):
        //   x = constant_op.constant(1)
        //   with self.assertRaises(TypeError):
        //     control_flow_ops.cond(True, lambda: x)
    }

    [Ignore("Todo")]
    [TestMethod]
    public void testCondDuplicateArg1()
    {
        // def testCondDuplicateArg1(self):
        //   x = constant_op.constant(1)
        //   with self.assertRaises(TypeError):
        //     control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
    }

    [Ignore("Todo")]
    [TestMethod]
    public void testCondDuplicateArg2()
    {
        // def testCondDuplicateArg2(self):
        //   x = constant_op.constant(1)
        //   with self.assertRaises(TypeError):
        //     control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
    }
}
}