Browse Source

minor changes

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
9dba680041
3 changed files with 8 additions and 24 deletions
  1. +0
    -22
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  2. +1
    -1
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
  3. +7
    -1
      test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs

+ 0
- 22
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -290,33 +290,11 @@ namespace Tensorflow
{
// TODO: here a chunk of original code is missing
/*
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")

if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")

with ops.name_scope(name, "cond", [pred]):
if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())

# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
*/

// Add the Switch to the graph.


+ 1
- 1
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
self.assertEquals(eval_scalar(z), 34);
});
}


+ 7
- 1
test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs View File

@@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest.ops_test
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_1 = null, b_2 = null;
Tensor b_1 = null, b_2 = null;
with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
{
b_1 = constant_op.constant(6.0);
@@ -157,6 +157,12 @@ namespace TensorFlowNET.UnitTest.ops_test
});
});
});
var z=tf.add(a_1, tf.multiply(b_2, b_1));
with(g.control_dependencies(new[] {z}), ctrl =>
{
var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
});
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
}


Loading…
Cancel
Save