diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 385caf1c..d525dc66 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Operations /// public class CondContext : ControlFlowContext { - private string _name; + /// /// The boolean tensor for the cond predicate @@ -207,6 +207,9 @@ namespace Tensorflow.Operations _values.Add(real_val.name); _external_values[real_val.name] = real_val; } + var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); + real_val = new[] {t0, t1}[_branch]; + _external_values[val.name] = real_val; } else { diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 0556b526..fef79c8d 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -37,7 +37,8 @@ namespace Tensorflow.Operations _context_stack = new Stack(); } - public string name { get; set; } + public string name { get => _name; } + protected string _name; public void __init__() { diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 245e38b5..deab05f1 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -279,12 +279,36 @@ namespace Tensorflow /// the Tensor to be used as the input at the given index. public void _update_input(int index, Tensor tensor) { - throw new NotImplementedException("_update_input"); + var input = _tf_input(index); + var output = tensor._as_tf_output(); + _assert_same_graph( tensor); + // Reset cached inputs. + _inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None // TODO: implement below code dependencies - //_assert_same_graph( tensor); - //// Reset cached inputs. - //_inputs_val = null; - //c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index)); + //c_api.UpdateEdge(_graph._c_graph, output, input); + } + + private void _assert_same_graph(Tensor tensor) + { + //TODO: implement + } + + /// + /// Create and return a new TF_Output for output_idx'th output of this op. + /// + public TF_Output _tf_output(int output_idx) + { + var tf_output = new TF_Output(op, output_idx); + return tf_output; + } + + /// + /// Create and return a new TF_Input for input_idx'th input of this op. + /// + public TF_Input _tf_input(int input_idx) + { + var tf_input = new TF_Input(op, input_idx); + return tf_input; } } } diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index eaa57681..2d8e3e8e 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -3,6 +3,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.ComponentModel; +using System.Diagnostics; using System.Linq; using System.Text; @@ -82,7 +83,10 @@ namespace Tensorflow } catch (Exception ex) { - Console.WriteLine(ex.ToString()); + Console.WriteLine(ex.ToString()); +#if DEBUG + Debugger.Break(); +#endif return default(TOut); } finally diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 1665bc40..4d5c58db 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -255,16 +255,17 @@ namespace Tensorflow public override string ToString() { - if(NDims == 0) - { - switch (dtype) - { - case TF_DataType.TF_INT32: - return Data()[0].ToString(); - } - } - - return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; + // this can throw IndexOutOfRangeException + //if(NDims == 0) + //{ + // switch (dtype) + // { + // case TF_DataType.TF_INT32: + // return Data()[0].ToString(); + // } + //} + + return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; } public void Dispose() diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index bace1dde..88064f21 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test }); } - [Ignore("Switch op gets not inserted correctly in the graph")] [TestMethod] public void TestCond() { @@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test //self.assertEqual(op.outputs, new object[0]); var op_input = op.inputs[0].op; self.assertEqual(op_input.type, "Switch"); - self.assertEqual(op_input.inputs[0], x); + self.assertEqual(op_input.inputs[0].name, x.name); self.assertEqual(op.graph, g); self.assertIsNotNone(op._get_control_flow_context()); - self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); + var cond_text = op._get_control_flow_context() as ControlFlowContext; + self.assertEqual(cond_text.name, "cond/cond_text"); }); - /* - @test_util.run_v1_only("b/120545219") - def testCond(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def true_fn(): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "cond/myop"), [x], []) - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return x - - control_flow_ops.cond(x < 10, true_fn, lambda: x) - - op = g.get_operation_by_name("cond/myop") - self.assertIsNotNone(op) - self.assertEqual(op.name, "cond/myop") - self.assertEqual(op.type, "IntInput") - self.assertEqual(op.outputs, []) - op_input = op.inputs[0].op - self.assertEqual(op_input.type, "Switch") - self.assertEqual(op_input.inputs[0], x) - self.assertEqual(op.graph, g) - # pylint: disable=protected-access - self.assertIsNotNone(op._get_control_flow_context()) - self.assertEqual(op._get_control_flow_context().name, - "cond/cond_text") - # pylint: enable=protected-access - */ } [Ignore("Todo: Port")]