Browse Source

fix #229

tags/v0.9
Oceania2018 6 years ago
parent
commit
6a1ed38e02
8 changed files with 33 additions and 31 deletions
  1. +2
    -7
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +1
    -2
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  6. +3
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  7. +5
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  8. +16
    -12
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

+ 2
- 7
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow.Operations
/// <param name="import_scope"></param>
public CondContext(Tensor pred = null,
Tensor pivot = null,
int? branch = null,
int branch = 0,
string name = "cond_text",
CondContextDef context_def = null,
string import_scope = null)
@@ -55,7 +55,7 @@ namespace Tensorflow.Operations
base.__init__();
_pred = pred;
_pivot = pivot;
_branch = branch; // 0 or 1 representing this branch
// Values considered to have been already seen in this context. pred is not
// included in this context.
_values.Add(pred.name);
@@ -105,11 +105,6 @@ namespace Tensorflow.Operations
_external_values[result.name] = result;
}
// for debug purpose
if(ops.get_default_graph()._nodes_by_name.Count > 60)
{
}

with(ops.control_dependencies(null), ctrl =>
{
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow

public void _add_control_input(Operation op)
{
// c_api.TF_AddControlInput(_operDesc, op);
//c_api.TF_AddControlInput(_operDesc, op);
c_api.AddControlInput(graph, _handle, op);
}



+ 3
- 3
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -27,9 +27,9 @@ namespace Tensorflow
for (int i = 0; i < NumInputs; i++)
{
var tf_outputs = Input(i);
var op = new Operation(tf_outputs.oper);
retval[i] = op.outputs[tf_outputs.index];
var tf_output = Input(i);
var op = new Operation(tf_output.oper);
retval[i] = op.outputs[tf_output.index];
}
_inputs = new InputList(retval);


+ 2
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -51,6 +51,7 @@ namespace Tensorflow
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

private NodeDef _node_def;
//[JsonIgnore]
public NodeDef node_def
{
get
@@ -290,6 +291,7 @@ namespace Tensorflow
// the updated inputs are reloaded from the c_api
c_api.UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
status.Check();
}
private void _assert_same_graph(Tensor tensor)


+ 1
- 2
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs View File

@@ -36,12 +36,11 @@ namespace Tensorflow
public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
var _result = (_op.outputs[0], _op.outputs[1]);
var _inputs_flat = _op.inputs;
var _attrs = ("T", _op.get_attr("T"));
// TODO: missing original code
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
return _result;
return (_op.outputs[0], _op.outputs[1]);
}

public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)


+ 3
- 3
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -128,9 +128,9 @@ namespace Tensorflow

public Tensor(Operation op, int value_index, TF_DataType dtype)
{
this.op = op;
this.value_index = value_index;
this._dtype = dtype;
_op = op;
_value_index = value_index;
_dtype = dtype;
_id = ops.uid();
}
}


+ 5
- 3
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -22,17 +22,19 @@ namespace Tensorflow
public int Id => _id;
//[JsonIgnore]
public Graph graph => op?.graph;
private Operation _op;
//[JsonIgnore]
public Operation op { get; }
public Operation op => _op;
//[JsonIgnore]
public Tensor[] outputs => op.outputs;

/// <summary>
/// The string name of this tensor.
/// </summary>
public string name => $"{(op == null ? "Operation was not named" : $"{op.name}:{value_index}")}";
public string name => $"{(op == null ? "Operation was not named" : $"{op.name}:{_value_index}")}";

public int value_index { get; }
private int _value_index;
public int value_index => _value_index;

private Status status = new Status();



+ 16
- 12
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -22,8 +22,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var y = tf.constant(5, name: "y");
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t2"),
() => tf.constant(55, name: "f5"));
() => tf.constant(22, name: "t22"),
() => tf.constant(55, name: "f55"));
int result = z.eval(sess);
assertEquals(result, 22);
@@ -41,8 +41,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var y = tf.constant(1, name: "y");
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t2"),
() => tf.constant(11, name: "f1"));
() => tf.constant(22, name: "t22"),
() => tf.constant(11, name: "f11"));
int result = z.eval(sess);
assertEquals(result, 11);
@@ -56,17 +56,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
with(tf.Session(graph), sess =>
{
var x = tf.constant(2);
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
var x = tf.constant(2, name: "x");
var y = tf.constant(5, name: "y");
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.multiply(x, 17),
() => tf.add(y, 23));
int result = z.eval(sess);
assertEquals(result, 34);
});
}
//[Ignore("This Test Fails due to missing edges in the graph!")]
[TestMethod]
public void testCondFalse()
{
@@ -76,8 +77,11 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
var x = tf.constant(2);
var y = tf.constant(1);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.multiply(x, 17),
() => tf.add(y, 23));
int result = z.eval(sess);
assertEquals(result, 24);
});


Loading…
Cancel
Save