Browse Source

Merge pull request #223 from henon/master

cond false-branch fixed
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
98c7dbf27e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 52 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  3. +29
    -5
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +5
    -1
      src/TensorFlowNET.Core/Python.cs
  5. +11
    -10
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +3
    -34
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+ 4
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Operations
/// </summary>
public class CondContext : ControlFlowContext
{
private string _name;

/// <summary>
/// The boolean tensor for the cond predicate
@@ -207,6 +207,9 @@ namespace Tensorflow.Operations
_values.Add(real_val.name);
_external_values[real_val.name] = real_val;
}
var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred);
real_val = new[] {t0, t1}[_branch];
_external_values[val.name] = real_val;
}
else
{


+ 2
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -37,7 +37,8 @@ namespace Tensorflow.Operations
_context_stack = new Stack<IControlFlowContext>();
}

public string name { get; set; }
public string name { get => _name; }
protected string _name;

public void __init__()
{


+ 29
- 5
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -279,12 +279,36 @@ namespace Tensorflow
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
throw new NotImplementedException("_update_input");
var input = _tf_input(index);
var output = tensor._as_tf_output();
_assert_same_graph( tensor);
// Reset cached inputs.
_inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None
// TODO: implement below code dependencies
//_assert_same_graph( tensor);
//// Reset cached inputs.
//_inputs_val = null;
//c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index));
//c_api.UpdateEdge(_graph._c_graph, output, input);
}

private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}

/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
}

/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
}
}
}

+ 5
- 1
src/TensorFlowNET.Core/Python.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Text;

@@ -82,7 +83,10 @@ namespace Tensorflow
}
catch (Exception ex)
{
Console.WriteLine(ex.ToString());
Console.WriteLine(ex.ToString());
#if DEBUG
Debugger.Break();
#endif
return default(TOut);
}
finally


+ 11
- 10
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -255,16 +255,17 @@ namespace Tensorflow

public override string ToString()
{
if(NDims == 0)
{
switch (dtype)
{
case TF_DataType.TF_INT32:
return Data<int>()[0].ToString();
}
}

return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
// this can throw IndexOutOfRangeException
//if(NDims == 0)
//{
// switch (dtype)
// {
// case TF_DataType.TF_INT32:
// return Data<int>()[0].ToString();
// }
//}

return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}

public void Dispose()


+ 3
- 34
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test
});
}
[Ignore("Switch op gets not inserted correctly in the graph")]
[TestMethod]
public void TestCond()
{
@@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test
//self.assertEqual(op.outputs, new object[0]);
var op_input = op.inputs[0].op;
self.assertEqual(op_input.type, "Switch");
self.assertEqual(op_input.inputs[0], x);
self.assertEqual(op_input.inputs[0].name, x.name);
self.assertEqual(op.graph, g);
self.assertIsNotNone(op._get_control_flow_context());
self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text");
var cond_text = op._get_control_flow_context() as ControlFlowContext;
self.assertEqual(cond_text.name, "cond/cond_text");
});
/*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
*/
}
[Ignore("Todo: Port")]


Loading…
Cancel
Save