Browse Source

Merge pull request #223 from henon/master

cond false-branch fixed
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
98c7dbf27e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 52 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  3. +29
    -5
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +5
    -1
      src/TensorFlowNET.Core/Python.cs
  5. +11
    -10
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +3
    -34
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+ 4
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Operations
/// </summary> /// </summary>
public class CondContext : ControlFlowContext public class CondContext : ControlFlowContext
{ {
private string _name;


/// <summary> /// <summary>
/// The boolean tensor for the cond predicate /// The boolean tensor for the cond predicate
@@ -207,6 +207,9 @@ namespace Tensorflow.Operations
_values.Add(real_val.name); _values.Add(real_val.name);
_external_values[real_val.name] = real_val; _external_values[real_val.name] = real_val;
} }
var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred);
real_val = new[] {t0, t1}[_branch];
_external_values[val.name] = real_val;
} }
else else
{ {


+ 2
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -37,7 +37,8 @@ namespace Tensorflow.Operations
_context_stack = new Stack<IControlFlowContext>(); _context_stack = new Stack<IControlFlowContext>();
} }


public string name { get; set; }
public string name { get => _name; }
protected string _name;


public void __init__() public void __init__()
{ {


+ 29
- 5
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -279,12 +279,36 @@ namespace Tensorflow
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> /// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor) public void _update_input(int index, Tensor tensor)
{ {
throw new NotImplementedException("_update_input");
var input = _tf_input(index);
var output = tensor._as_tf_output();
_assert_same_graph( tensor);
// Reset cached inputs.
_inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None
// TODO: implement below code dependencies // TODO: implement below code dependencies
//_assert_same_graph( tensor);
//// Reset cached inputs.
//_inputs_val = null;
//c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index));
//c_api.UpdateEdge(_graph._c_graph, output, input);
}

private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}

/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
}

/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
} }
} }
} }

+ 5
- 1
src/TensorFlowNET.Core/Python.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Text; using System.Text;


@@ -82,7 +83,10 @@ namespace Tensorflow
} }
catch (Exception ex) catch (Exception ex)
{ {
Console.WriteLine(ex.ToString());
Console.WriteLine(ex.ToString());
#if DEBUG
Debugger.Break();
#endif
return default(TOut); return default(TOut);
} }
finally finally


+ 11
- 10
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -255,16 +255,17 @@ namespace Tensorflow


public override string ToString() public override string ToString()
{ {
if(NDims == 0)
{
switch (dtype)
{
case TF_DataType.TF_INT32:
return Data<int>()[0].ToString();
}
}

return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
// this can throw IndexOutOfRangeException
//if(NDims == 0)
//{
// switch (dtype)
// {
// case TF_DataType.TF_INT32:
// return Data<int>()[0].ToString();
// }
//}

return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
} }


public void Dispose() public void Dispose()


+ 3
- 34
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test
}); });
} }
[Ignore("Switch op gets not inserted correctly in the graph")]
[TestMethod] [TestMethod]
public void TestCond() public void TestCond()
{ {
@@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test
//self.assertEqual(op.outputs, new object[0]); //self.assertEqual(op.outputs, new object[0]);
var op_input = op.inputs[0].op; var op_input = op.inputs[0].op;
self.assertEqual(op_input.type, "Switch"); self.assertEqual(op_input.type, "Switch");
self.assertEqual(op_input.inputs[0], x);
self.assertEqual(op_input.inputs[0].name, x.name);
self.assertEqual(op.graph, g); self.assertEqual(op.graph, g);
self.assertIsNotNone(op._get_control_flow_context()); self.assertIsNotNone(op._get_control_flow_context());
self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text");
var cond_text = op._get_control_flow_context() as ControlFlowContext;
self.assertEqual(cond_text.name, "cond/cond_text");
}); });
/*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
*/
} }
[Ignore("Todo: Port")] [Ignore("Todo: Port")]


Loading…
Cancel
Save