@@ -10,7 +10,7 @@ namespace Tensorflow.Operations | |||||
/// </summary> | /// </summary> | ||||
public class CondContext : ControlFlowContext | public class CondContext : ControlFlowContext | ||||
{ | { | ||||
private string _name; | |||||
/// <summary> | /// <summary> | ||||
/// The boolean tensor for the cond predicate | /// The boolean tensor for the cond predicate | ||||
@@ -207,6 +207,9 @@ namespace Tensorflow.Operations | |||||
_values.Add(real_val.name); | _values.Add(real_val.name); | ||||
_external_values[real_val.name] = real_val; | _external_values[real_val.name] = real_val; | ||||
} | } | ||||
var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||||
real_val = new[] {t0, t1}[_branch]; | |||||
_external_values[val.name] = real_val; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -37,7 +37,8 @@ namespace Tensorflow.Operations | |||||
_context_stack = new Stack<IControlFlowContext>(); | _context_stack = new Stack<IControlFlowContext>(); | ||||
} | } | ||||
public string name { get; set; } | |||||
public string name { get => _name; } | |||||
protected string _name; | |||||
public void __init__() | public void __init__() | ||||
{ | { | ||||
@@ -279,12 +279,36 @@ namespace Tensorflow | |||||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | ||||
public void _update_input(int index, Tensor tensor) | public void _update_input(int index, Tensor tensor) | ||||
{ | { | ||||
throw new NotImplementedException("_update_input"); | |||||
var input = _tf_input(index); | |||||
var output = tensor._as_tf_output(); | |||||
_assert_same_graph( tensor); | |||||
// Reset cached inputs. | |||||
_inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None | |||||
// TODO: implement below code dependencies | // TODO: implement below code dependencies | ||||
//_assert_same_graph( tensor); | |||||
//// Reset cached inputs. | |||||
//_inputs_val = null; | |||||
//c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index)); | |||||
//c_api.UpdateEdge(_graph._c_graph, output, input); | |||||
} | |||||
private void _assert_same_graph(Tensor tensor) | |||||
{ | |||||
//TODO: implement | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||||
/// </summary> | |||||
public TF_Output _tf_output(int output_idx) | |||||
{ | |||||
var tf_output = new TF_Output(op, output_idx); | |||||
return tf_output; | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||||
/// </summary> | |||||
public TF_Input _tf_input(int input_idx) | |||||
{ | |||||
var tf_input = new TF_Input(op, input_idx); | |||||
return tf_input; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -3,6 +3,7 @@ using System; | |||||
using System.Collections; | using System.Collections; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.ComponentModel; | using System.ComponentModel; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
@@ -82,7 +83,10 @@ namespace Tensorflow | |||||
} | } | ||||
catch (Exception ex) | catch (Exception ex) | ||||
{ | { | ||||
Console.WriteLine(ex.ToString()); | |||||
Console.WriteLine(ex.ToString()); | |||||
#if DEBUG | |||||
Debugger.Break(); | |||||
#endif | |||||
return default(TOut); | return default(TOut); | ||||
} | } | ||||
finally | finally | ||||
@@ -255,16 +255,17 @@ namespace Tensorflow | |||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
if(NDims == 0) | |||||
{ | |||||
switch (dtype) | |||||
{ | |||||
case TF_DataType.TF_INT32: | |||||
return Data<int>()[0].ToString(); | |||||
} | |||||
} | |||||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
// this can throw IndexOutOfRangeException | |||||
//if(NDims == 0) | |||||
//{ | |||||
// switch (dtype) | |||||
// { | |||||
// case TF_DataType.TF_INT32: | |||||
// return Data<int>()[0].ToString(); | |||||
// } | |||||
//} | |||||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
}); | }); | ||||
} | } | ||||
[Ignore("Switch op gets not inserted correctly in the graph")] | |||||
[TestMethod] | [TestMethod] | ||||
public void TestCond() | public void TestCond() | ||||
{ | { | ||||
@@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
//self.assertEqual(op.outputs, new object[0]); | //self.assertEqual(op.outputs, new object[0]); | ||||
var op_input = op.inputs[0].op; | var op_input = op.inputs[0].op; | ||||
self.assertEqual(op_input.type, "Switch"); | self.assertEqual(op_input.type, "Switch"); | ||||
self.assertEqual(op_input.inputs[0], x); | |||||
self.assertEqual(op_input.inputs[0].name, x.name); | |||||
self.assertEqual(op.graph, g); | self.assertEqual(op.graph, g); | ||||
self.assertIsNotNone(op._get_control_flow_context()); | self.assertIsNotNone(op._get_control_flow_context()); | ||||
self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); | |||||
var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||||
self.assertEqual(cond_text.name, "cond/cond_text"); | |||||
}); | }); | ||||
/* | |||||
@test_util.run_v1_only("b/120545219") | |||||
def testCond(self): | |||||
g = ops.Graph() | |||||
with g.as_default(): | |||||
x = test_ops.int_output() | |||||
def true_fn(): | |||||
ops._create_c_op(ops.get_default_graph(), | |||||
ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||||
new_ops = g._add_new_tf_operations() | |||||
self.assertEqual(len(new_ops), 1) | |||||
return x | |||||
control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||||
op = g.get_operation_by_name("cond/myop") | |||||
self.assertIsNotNone(op) | |||||
self.assertEqual(op.name, "cond/myop") | |||||
self.assertEqual(op.type, "IntInput") | |||||
self.assertEqual(op.outputs, []) | |||||
op_input = op.inputs[0].op | |||||
self.assertEqual(op_input.type, "Switch") | |||||
self.assertEqual(op_input.inputs[0], x) | |||||
self.assertEqual(op.graph, g) | |||||
# pylint: disable=protected-access | |||||
self.assertIsNotNone(op._get_control_flow_context()) | |||||
self.assertEqual(op._get_control_flow_context().name, | |||||
"cond/cond_text") | |||||
# pylint: enable=protected-access | |||||
*/ | |||||
} | } | ||||
[Ignore("Todo: Port")] | [Ignore("Todo: Port")] | ||||