Browse Source

added original complex cond test cases

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
7e77cb5c5e
2 changed files with 44 additions and 11 deletions
  1. +8
    -9
      test/TensorFlowNET.UnitTest/PythonTest.cs
  2. +36
    -2
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

+ 8
- 9
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest
/// </summary> /// </summary>
public T evaluate<T>(Tensor tensor) public T evaluate<T>(Tensor tensor)
{ {
var results = new Dictionary<string, NDArray>();
object result = null;
// if context.executing_eagerly(): // if context.executing_eagerly():
// return self._eval_helper(tensors) // return self._eval_helper(tensors)
// else: // else:
@@ -146,26 +146,25 @@ namespace TensorFlowNET.UnitTest
var sess = ops.get_default_session(); var sess = ops.get_default_session();
if (sess == null) if (sess == null)
sess = self.session(); sess = self.session();
T t_result = (T)(object)null;
with<Session>(sess, s => with<Session>(sess, s =>
{ {
var ndarray=tensor.eval();
var ndarray=tensor.eval();
if (typeof(T) == typeof(double)) if (typeof(T) == typeof(double))
{ {
double d = ndarray;
t_result = (T)(object)d;
double x = ndarray;
result=x;
} }
else if (typeof(T) == typeof(int)) else if (typeof(T) == typeof(int))
{ {
int d = ndarray;
t_result = (T) (object) d;
int x = ndarray;
result = x;
} }
else else
{ {
t_result = (T)(object)ndarray;
result = ndarray;
} }
}); });
return t_result;
return (T)result;
} }
} }


+ 36
- 2
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -12,7 +12,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
public class CondTestCases : PythonTest public class CondTestCases : PythonTest
{ {
[TestMethod] [TestMethod]
public void testCondTrue()
public void testCondTrue_ConstOnly()
{ {
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();
@@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
} }
[TestMethod] [TestMethod]
public void testCondFalse()
public void testCondFalse_ConstOnly()
{ {
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();
@@ -49,6 +49,40 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
}); });
} }
[TestMethod]
public void testCondTrue()
{
var graph = tf.Graph().as_default();
with(tf.Session(graph), sess =>
{
var x = tf.constant(2);
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
int result = z.eval(sess);
assertEquals(result, 34);
});
}
//[Ignore("This Test Fails due to missing edges in the graph!")]
[TestMethod]
public void testCondFalse()
{
var graph = tf.Graph().as_default();
with(tf.Session(graph), sess =>
{
var x = tf.constant(2);
var y = tf.constant(1);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
int result = z.eval(sess);
assertEquals(result, 24);
});
}
// NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api // NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api
} }


Loading…
Cancel
Save