@@ -10,7 +10,7 @@ namespace Tensorflow.Operations | |||
/// </summary> | |||
public class CondContext : ControlFlowContext | |||
{ | |||
private string _name; | |||
/// <summary> | |||
/// The boolean tensor for the cond predicate | |||
@@ -37,7 +37,8 @@ namespace Tensorflow.Operations | |||
_context_stack = new Stack<IControlFlowContext>(); | |||
} | |||
public string name { get; set; } | |||
public string name { get => _name; } | |||
protected string _name; | |||
public void __init__() | |||
{ | |||
@@ -279,12 +279,36 @@ namespace Tensorflow | |||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||
public void _update_input(int index, Tensor tensor) | |||
{ | |||
throw new NotImplementedException("_update_input"); | |||
var input = _tf_input(index); | |||
var output = tensor._as_tf_output(); | |||
_assert_same_graph( tensor); | |||
// Reset cached inputs. | |||
_inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None | |||
// TODO: implement below code dependencies | |||
//_assert_same_graph( tensor); | |||
//// Reset cached inputs. | |||
//_inputs_val = null; | |||
//c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index)); | |||
//c_api.UpdateEdge(_graph._c_graph, output, input); | |||
} | |||
private void _assert_same_graph(Tensor tensor) | |||
{ | |||
//TODO: implement | |||
} | |||
/// <summary> | |||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||
/// </summary> | |||
public TF_Output _tf_output(int output_idx) | |||
{ | |||
var tf_output = new TF_Output(op, output_idx); | |||
return tf_output; | |||
} | |||
/// <summary> | |||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||
/// </summary> | |||
public TF_Input _tf_input(int input_idx) | |||
{ | |||
var tf_input = new TF_Input(op, input_idx); | |||
return tf_input; | |||
} | |||
} | |||
} |
@@ -255,16 +255,17 @@ namespace Tensorflow | |||
public override string ToString() | |||
{ | |||
if(NDims == 0) | |||
{ | |||
switch (dtype) | |||
{ | |||
case TF_DataType.TF_INT32: | |||
return Data<int>()[0].ToString(); | |||
} | |||
} | |||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||
// this can throw IndexOutOfRangeException | |||
//if(NDims == 0) | |||
//{ | |||
// switch (dtype) | |||
// { | |||
// case TF_DataType.TF_INT32: | |||
// return Data<int>()[0].ToString(); | |||
// } | |||
//} | |||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||
} | |||
public void Dispose() | |||
@@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
}); | |||
} | |||
[Ignore("Switch op gets not inserted correctly in the graph")] | |||
[TestMethod] | |||
public void TestCond() | |||
{ | |||
@@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
//self.assertEqual(op.outputs, new object[0]); | |||
var op_input = op.inputs[0].op; | |||
self.assertEqual(op_input.type, "Switch"); | |||
self.assertEqual(op_input.inputs[0], x); | |||
self.assertEqual(op_input.inputs[0].name, x.name); | |||
self.assertEqual(op.graph, g); | |||
self.assertIsNotNone(op._get_control_flow_context()); | |||
self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); | |||
var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||
self.assertEqual(cond_text.name, "cond/cond_text"); | |||
}); | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testCond(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def true_fn(): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return x | |||
control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||
op = g.get_operation_by_name("cond/myop") | |||
self.assertIsNotNone(op) | |||
self.assertEqual(op.name, "cond/myop") | |||
self.assertEqual(op.type, "IntInput") | |||
self.assertEqual(op.outputs, []) | |||
op_input = op.inputs[0].op | |||
self.assertEqual(op_input.type, "Switch") | |||
self.assertEqual(op_input.inputs[0], x) | |||
self.assertEqual(op.graph, g) | |||
# pylint: disable=protected-access | |||
self.assertIsNotNone(op._get_control_flow_context()) | |||
self.assertEqual(op._get_control_flow_context().name, | |||
"cond/cond_text") | |||
# pylint: enable=protected-access | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||