diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index c00e2c0e..7b70a76a 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -37,7 +37,7 @@ namespace Tensorflow.Operations
///
public CondContext(Tensor pred = null,
Tensor pivot = null,
- int? branch = null,
+ int branch = 0,
string name = "cond_text",
CondContextDef context_def = null,
string import_scope = null)
@@ -55,7 +55,7 @@ namespace Tensorflow.Operations
base.__init__();
_pred = pred;
_pivot = pivot;
-
+ _branch = branch; // 0 or 1 representing this branch
// Values considered to have been already seen in this context. pred is not
// included in this context.
_values.Add(pred.name);
@@ -105,11 +105,6 @@ namespace Tensorflow.Operations
_external_values[result.name] = result;
}
- // for debug purpose
- if(ops.get_default_graph()._nodes_by_name.Count > 60)
- {
- }
-
with(ops.control_dependencies(null), ctrl =>
{
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 73f9d847..262d8e75 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -29,7 +29,7 @@ namespace Tensorflow
public void _add_control_input(Operation op)
{
- // c_api.TF_AddControlInput(_operDesc, op);
+ //c_api.TF_AddControlInput(_operDesc, op);
c_api.AddControlInput(graph, _handle, op);
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
index 26c9c08c..349a7603 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
@@ -27,9 +27,9 @@ namespace Tensorflow
for (int i = 0; i < NumInputs; i++)
{
- var tf_outputs = Input(i);
- var op = new Operation(tf_outputs.oper);
- retval[i] = op.outputs[tf_outputs.index];
+ var tf_output = Input(i);
+ var op = new Operation(tf_output.oper);
+ retval[i] = op.outputs[tf_output.index];
}
_inputs = new InputList(retval);
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index e0caef72..5915c216 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -51,6 +51,7 @@ namespace Tensorflow
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
private NodeDef _node_def;
+ //[JsonIgnore]
public NodeDef node_def
{
get
@@ -290,6 +291,7 @@ namespace Tensorflow
// the updated inputs are reloaded from the c_api
c_api.UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
+ status.Check();
}
private void _assert_same_graph(Tensor tensor)
diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
index 21daf844..31e2cad3 100644
--- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
@@ -36,12 +36,11 @@ namespace Tensorflow
public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
- var _result = (_op.outputs[0], _op.outputs[1]);
var _inputs_flat = _op.inputs;
var _attrs = ("T", _op.get_attr("T"));
// TODO: missing original code
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
- return _result;
+ return (_op.outputs[0], _op.outputs[1]);
}
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index b52cd630..5f26becd 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -128,9 +128,9 @@ namespace Tensorflow
public Tensor(Operation op, int value_index, TF_DataType dtype)
{
- this.op = op;
- this.value_index = value_index;
- this._dtype = dtype;
+ _op = op;
+ _value_index = value_index;
+ _dtype = dtype;
_id = ops.uid();
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index ee165c93..6e92fe7c 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -22,17 +22,19 @@ namespace Tensorflow
public int Id => _id;
//[JsonIgnore]
public Graph graph => op?.graph;
+ private Operation _op;
//[JsonIgnore]
- public Operation op { get; }
+ public Operation op => _op;
//[JsonIgnore]
public Tensor[] outputs => op.outputs;
///
/// The string name of this tensor.
///
- public string name => $"{(op == null ? "Operation was not named" : $"{op.name}:{value_index}")}";
+ public string name => $"{(op == null ? "Operation was not named" : $"{op.name}:{_value_index}")}";
- public int value_index { get; }
+ private int _value_index;
+ public int value_index => _value_index;
private Status status = new Status();
diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
index 9b259e57..75e35716 100644
--- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
+++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
@@ -22,8 +22,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var y = tf.constant(5, name: "y");
var z = control_flow_ops.cond(tf.less(x, y),
- () => tf.constant(22, name: "t2"),
- () => tf.constant(55, name: "f5"));
+ () => tf.constant(22, name: "t22"),
+ () => tf.constant(55, name: "f55"));
int result = z.eval(sess);
assertEquals(result, 22);
@@ -41,8 +41,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var y = tf.constant(1, name: "y");
var z = control_flow_ops.cond(tf.less(x, y),
- () => tf.constant(22, name: "t2"),
- () => tf.constant(11, name: "f1"));
+ () => tf.constant(22, name: "t22"),
+ () => tf.constant(11, name: "f11"));
int result = z.eval(sess);
assertEquals(result, 11);
@@ -56,17 +56,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
with(tf.Session(graph), sess =>
{
- var x = tf.constant(2);
- var y = tf.constant(5);
- var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
- () => tf.add(y, tf.constant(23)));
- //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
+ var x = tf.constant(2, name: "x");
+ var y = tf.constant(5, name: "y");
+
+ var z = control_flow_ops.cond(tf.less(x, y),
+ () => tf.multiply(x, 17),
+ () => tf.add(y, 23));
+
int result = z.eval(sess);
assertEquals(result, 34);
});
}
- //[Ignore("This Test Fails due to missing edges in the graph!")]
[TestMethod]
public void testCondFalse()
{
@@ -76,8 +77,11 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
var x = tf.constant(2);
var y = tf.constant(1);
- var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
- () => tf.add(y, tf.constant(23)));
+
+ var z = control_flow_ops.cond(tf.less(x, y),
+ () => tf.multiply(x, 17),
+ () => tf.add(y, 23));
+
int result = z.eval(sess);
assertEquals(result, 24);
});