Browse Source

WhileContext BuildLoop

tags/v0.12
Oceania2018 6 years ago
parent
commit
07f70f9425
12 changed files with 117 additions and 40 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  2. +25
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs
  3. +32
    -0
      src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs
  4. +16
    -6
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/TensorArray.cs
  6. +3
    -3
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  7. +27
    -5
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  8. +2
    -5
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
  9. +0
    -1
      src/TensorFlowNET.Core/Util/nest.py.cs
  10. +4
    -4
      src/TensorFlowNET.Core/tensorflow.cs
  11. +2
    -3
      test/TensorFlowNET.UnitTest/CSession.cs
  12. +3
    -10
      test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation
=> control_flow_ops.group(inputs, name: name);

public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
TensorShape shape_invariants = null,
int parallel_iterations = 10,
bool back_prop = true,
@@ -52,7 +52,7 @@ namespace Tensorflow
swap_memory: swap_memory,
name: name,
maximum_iterations: maximum_iterations,
return_same_structure: return_same_structure);
return_same_structure: return_same_structure);*/

public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> ops.control_dependencies(control_inputs);


+ 25
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
internal class LoopVar<TItem>
{
public Tensor Counter { get; }
public TItem[] Items { get; }
public TItem Item { get; }

public LoopVar(Tensor counter, TItem[] items)
{
Counter = counter;
Items = items;
}

public LoopVar(Tensor counter, TItem item)
{
Counter = counter;
Item = item;
}
}
}

+ 32
- 0
src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs View File

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
internal class BodyItemInRnnWhileLoop
{
/// <summary>
/// int32 scalar Tensor.
/// </summary>
public Tensor time { get; set; }
/// <summary>
/// List of `TensorArray`s that represent the output.
/// </summary>
public TensorArray[] output_ta_t { get; set; }
/// <summary>
/// nested tuple of vector tensors that represent the state.
/// </summary>
public Tensor state { get; set; }

public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state)
{
this.time = time;
this.output_ta_t = output_ta_t;
this.state = state;
}

public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item)
=> (item.time, item.output_ta_t, item.state);
}
}

+ 16
- 6
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -145,7 +145,7 @@ namespace Tensorflow.Operations
{
var ta = new TensorArray(dtype: dtype_,
size: time_steps,
element_shape: new[] { element_shape },
element_shape: element_shape,
tensor_array_name: base_name + name);
return ta;
};
@@ -178,19 +178,29 @@ namespace Tensorflow.Operations

// Make sure that we run at least 1 step, if necessary, to ensure
// the TensorArrays pick up the dynamic shape.
Tensor loop_bound;
Tensor loop_bound = null;
if (in_graph_mode)
loop_bound = math_ops.minimum(
time_steps, math_ops.maximum(1, max_sequence_length));

/*Func<Tensor, Tensor> cond = (ctime) =>
Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) =>
{
return null;
return time < loop_bound;
};

control_flow_ops.while_loop(
// Take a time step of the dynamic RNN.
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
{
return item;
};

control_flow_ops.while_loop<BodyItemInRnnWhileLoop>(
cond: cond,
body = );*/
body: _time_step,
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),
parallel_iterations: parallel_iterations,
maximum_iterations: time_steps,
swap_memory: swap_memory);

throw new NotImplementedException("");
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/TensorArray.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow.Operations

public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape[] element_shape = null,
bool infer_shape = true, TensorShape element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_implementation = new _GraphTensorArray(dtype,


+ 3
- 3
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow.Operations

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape[] element_shape = null,
bool infer_shape = true, TensorShape element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
clear_after_read = clear_after_read ?? true;
@@ -68,7 +68,7 @@ namespace Tensorflow.Operations
else
{
_infer_shape = true;
_element_shape = new List<TensorShape> { };
_element_shape = new List<TensorShape> { element_shape };
}

tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope =>
@@ -135,7 +135,7 @@ namespace Tensorflow.Operations

var ta = new TensorArray(_dtype,
infer_shape:_infer_shape,
element_shape: _element_shape.ToArray(),
element_shape: _element_shape[0],
dynamic_size: _dynamic_size,
handle: _handle,
flow: flow_out,


+ 27
- 5
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -485,7 +485,7 @@ namespace Tensorflow
});
}

public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows)
{
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
return tensors_or_flows;
@@ -591,18 +591,18 @@ namespace Tensorflow
/// <param name="body"></param>
/// <param name="loop_vars"></param>
/// <param name="i"></param>
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
TensorShape shape_invariants = null,
int parallel_iterations = 10,
bool back_prop = true,
bool swap_memory = false,
string name = null,
int? maximum_iterations = null,
Tensor maximum_iterations = null,
bool return_same_structure = false)
{
tf_with(ops.name_scope(name, "while", loop_vars), scope =>
{
if (loop_vars == null || loop_vars.Length == 0)
if (loop_vars == null)
throw new ValueError("No loop variables provided");
if (cond == null)
throw new ValueError("cond must be callable.");
@@ -611,6 +611,28 @@ namespace Tensorflow
if (parallel_iterations < 1)
throw new ValueError("parallel_iterations must be a positive integer.");

var try_to_pack = loop_vars is Tensor && !return_same_structure;
var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter");
var orig_cond = cond;
var orig_body = body;

LoopVar<TItem> loop_vars_1 = null;
Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null;
Func<Tensor, TItem, Tensor> cond_buildloop = null;

if (try_to_pack)
{

}
else
{
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars);
cond_buildloop = (i, lv) =>
math_ops.logical_and(i < maximum_iterations, orig_cond(lv));
body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv));
}
try_to_pack = false;

var loop_context = new WhileContext(
maximum_iterations: maximum_iterations,
parallel_iterations: parallel_iterations,
@@ -620,7 +642,7 @@ namespace Tensorflow
if (loop_context.outer_context == null)
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);

var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants,
return_same_structure);

if (maximum_iterations != null)


+ 2
- 5
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs View File

@@ -28,12 +28,9 @@ namespace Tensorflow
}

public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "", string name = null)
{
if (tensor_array_name == null)
tensor_array_name = string.Empty;

var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
{
size,


+ 0
- 1
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -223,7 +223,6 @@ namespace Tensorflow.Util
private static void _flatten_recursive<T>(T obj, List<T> list)
{
switch(obj)
{
case IDictionary dict:


+ 4
- 4
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -93,14 +93,14 @@ namespace Tensorflow
return new Session().as_default();
}

public Session Session(Graph graph, SessionOptions opts = null)
public Session Session(Graph graph, ConfigProto config = null)
{
return new Session(graph, opts: opts).as_default();
return new Session(graph, config: config).as_default();
}

public Session Session(SessionOptions opts)
public Session Session(ConfigProto config)
{
return new Session(null, opts).as_default();
return new Session(null, config).as_default();
}

public void __init__()


+ 2
- 3
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest
{
lock (Locks.ProcessWide)
{
var opts = new SessionOptions();
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4});
session_ = new Session(graph, opts, s);
var config = new ConfigProto {InterOpParallelismThreads = 4};
session_ = new Session(graph, config, s);
}
}



+ 3
- 10
test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs View File

@@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c"));
var r = control_flow_ops.while_loop(c, b, new[] { i });
var r = control_flow_ops.while_loop(c, b, i);
}
private void _testWhileContextHelper(int? maximum_iterations = null)
private void _testWhileContextHelper(int maximum_iterations)
{
// TODO: implement missing code dependencies
using (var sess = this.cached_session())
@@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
control_flow_ops.while_loop(
c, b, new[] { i }, maximum_iterations: maximum_iterations);
c, b, i , maximum_iterations: tf.constant(maximum_iterations));
foreach (Operation op in sess.graph.get_operations())
{
var control_flow_context = op._get_control_flow_context();
@@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
}
}
[Ignore("TODO")]
[TestMethod]
public void testWhileContext()
{
_testWhileContextHelper();
}
[Ignore("TODO")]
[TestMethod]
public void testWhileContextWithMaximumIterations()


Loading…
Cancel
Save