From 6a1ed38e027aaa35524be981fc5a92099d5a6319 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 14 Apr 2019 20:09:31 -0500 Subject: [PATCH] fix #229 --- .../Operations/ControlFlows/CondContext.cs | 9 ++---- .../Operations/Operation.Control.cs | 2 +- .../Operations/Operation.Input.cs | 6 ++-- .../Operations/Operation.cs | 2 ++ .../Operations/gen_control_flow_ops.py.cs | 3 +- .../Tensors/Tensor.Creation.cs | 6 ++-- src/TensorFlowNET.Core/Tensors/Tensor.cs | 8 ++++-- .../control_flow_ops_test/CondTestCases.cs | 28 +++++++++++-------- 8 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index c00e2c0e..7b70a76a 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -37,7 +37,7 @@ namespace Tensorflow.Operations /// public CondContext(Tensor pred = null, Tensor pivot = null, - int? branch = null, + int branch = 0, string name = "cond_text", CondContextDef context_def = null, string import_scope = null) @@ -55,7 +55,7 @@ namespace Tensorflow.Operations base.__init__(); _pred = pred; _pivot = pivot; - + _branch = branch; // 0 or 1 representing this branch // Values considered to have been already seen in this context. pred is not // included in this context. _values.Add(pred.name); @@ -105,11 +105,6 @@ namespace Tensorflow.Operations _external_values[result.name] = result; } - // for debug purpose - if(ops.get_default_graph()._nodes_by_name.Count > 60) - { - } - with(ops.control_dependencies(null), ctrl => { var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 73f9d847..262d8e75 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -29,7 +29,7 @@ namespace Tensorflow public void _add_control_input(Operation op) { - // c_api.TF_AddControlInput(_operDesc, op); + //c_api.TF_AddControlInput(_operDesc, op); c_api.AddControlInput(graph, _handle, op); } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 26c9c08c..349a7603 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -27,9 +27,9 @@ namespace Tensorflow for (int i = 0; i < NumInputs; i++) { - var tf_outputs = Input(i); - var op = new Operation(tf_outputs.oper); - retval[i] = op.outputs[tf_outputs.index]; + var tf_output = Input(i); + var op = new Operation(tf_output.oper); + retval[i] = op.outputs[tf_output.index]; } _inputs = new InputList(retval); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index e0caef72..5915c216 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -51,6 +51,7 @@ namespace Tensorflow public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); private NodeDef _node_def; + //[JsonIgnore] public NodeDef node_def { get @@ -290,6 +291,7 @@ namespace Tensorflow // the updated inputs are reloaded from the c_api c_api.UpdateEdge(_graph, output, input, status); //var updated_inputs = inputs; + status.Check(); } private void _assert_same_graph(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index 21daf844..31e2cad3 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -36,12 +36,11 @@ namespace Tensorflow public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) { var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); - var _result = (_op.outputs[0], _op.outputs[1]); var _inputs_flat = _op.inputs; var _attrs = ("T", _op.get_attr("T")); // TODO: missing original code //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); - return _result; + return (_op.outputs[0], _op.outputs[1]); } public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index b52cd630..5f26becd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -128,9 +128,9 @@ namespace Tensorflow public Tensor(Operation op, int value_index, TF_DataType dtype) { - this.op = op; - this.value_index = value_index; - this._dtype = dtype; + _op = op; + _value_index = value_index; + _dtype = dtype; _id = ops.uid(); } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index ee165c93..6e92fe7c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -22,17 +22,19 @@ namespace Tensorflow public int Id => _id; //[JsonIgnore] public Graph graph => op?.graph; + private Operation _op; //[JsonIgnore] - public Operation op { get; } + public Operation op => _op; //[JsonIgnore] public Tensor[] outputs => op.outputs; /// /// The string name of this tensor. /// - public string name => $"{(op == null ? "Operation was not named" : $"{op.name}:{value_index}")}"; + public string name => $"{(op == null ? "Operation was not named" : $"{op.name}:{_value_index}")}"; - public int value_index { get; } + private int _value_index; + public int value_index => _value_index; private Status status = new Status(); diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 9b259e57..75e35716 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -22,8 +22,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var y = tf.constant(5, name: "y"); var z = control_flow_ops.cond(tf.less(x, y), - () => tf.constant(22, name: "t2"), - () => tf.constant(55, name: "f5")); + () => tf.constant(22, name: "t22"), + () => tf.constant(55, name: "f55")); int result = z.eval(sess); assertEquals(result, 22); @@ -41,8 +41,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var y = tf.constant(1, name: "y"); var z = control_flow_ops.cond(tf.less(x, y), - () => tf.constant(22, name: "t2"), - () => tf.constant(11, name: "f1")); + () => tf.constant(22, name: "t22"), + () => tf.constant(11, name: "f11")); int result = z.eval(sess); assertEquals(result, 11); @@ -56,17 +56,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test with(tf.Session(graph), sess => { - var x = tf.constant(2); - var y = tf.constant(5); - var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); - //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); + var x = tf.constant(2, name: "x"); + var y = tf.constant(5, name: "y"); + + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.multiply(x, 17), + () => tf.add(y, 23)); + int result = z.eval(sess); assertEquals(result, 34); }); } - //[Ignore("This Test Fails due to missing edges in the graph!")] [TestMethod] public void testCondFalse() { @@ -76,8 +77,11 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test { var x = tf.constant(2); var y = tf.constant(1); - var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); + + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.multiply(x, 17), + () => tf.add(y, 23)); + int result = z.eval(sess); assertEquals(result, 24); });