Browse Source

minor changes

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
9dba680041
3 changed files with 8 additions and 24 deletions
  1. +0
    -22
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  2. +1
    -1
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
  3. +7
    -1
      test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs

+ 0
- 22
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -290,33 +290,11 @@ namespace Tensorflow
{ {
// TODO: here a chunk of original code is missing // TODO: here a chunk of original code is missing
/* /*
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")

if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")

with ops.name_scope(name, "cond", [pred]): with ops.name_scope(name, "cond", [pred]):
if context.executing_eagerly(): if context.executing_eagerly():
if pred: if pred:
return _UnpackIfSingleton(true_fn()) return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn()) return _UnpackIfSingleton(false_fn())

# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
*/ */


// Add the Switch to the graph. // Add the Switch to the graph.


+ 1
- 1
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var y = tf.constant(5); var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23))); () => tf.add(y, tf.constant(23)));
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
self.assertEquals(eval_scalar(z), 34); self.assertEquals(eval_scalar(z), 34);
}); });
} }


+ 7
- 1
test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs View File

@@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest.ops_test
var a_2 = constant_op.constant(3.0); var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0); var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0); var a_4 = constant_op.constant(5.0);
Operation b_1 = null, b_2 = null;
Tensor b_1 = null, b_2 = null;
with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
{ {
b_1 = constant_op.constant(6.0); b_1 = constant_op.constant(6.0);
@@ -157,6 +157,12 @@ namespace TensorFlowNET.UnitTest.ops_test
}); });
}); });
}); });
var z=tf.add(a_1, tf.multiply(b_2, b_1));
with(g.control_dependencies(new[] {z}), ctrl =>
{
var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
});
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
} }


Loading…
Cancel
Save