@@ -37,7 +37,7 @@ namespace Tensorflow | |||||
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation | public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation | ||||
=> control_flow_ops.group(inputs, name: name); | => control_flow_ops.group(inputs, name: name); | ||||
public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
TensorShape shape_invariants = null, | TensorShape shape_invariants = null, | ||||
int parallel_iterations = 10, | int parallel_iterations = 10, | ||||
bool back_prop = true, | bool back_prop = true, | ||||
@@ -52,7 +52,7 @@ namespace Tensorflow | |||||
swap_memory: swap_memory, | swap_memory: swap_memory, | ||||
name: name, | name: name, | ||||
maximum_iterations: maximum_iterations, | maximum_iterations: maximum_iterations, | ||||
return_same_structure: return_same_structure); | |||||
return_same_structure: return_same_structure);*/ | |||||
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | ||||
=> ops.control_dependencies(control_inputs); | => ops.control_dependencies(control_inputs); | ||||
@@ -0,0 +1,25 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
internal class LoopVar<TItem> | |||||
{ | |||||
public Tensor Counter { get; } | |||||
public TItem[] Items { get; } | |||||
public TItem Item { get; } | |||||
public LoopVar(Tensor counter, TItem[] items) | |||||
{ | |||||
Counter = counter; | |||||
Items = items; | |||||
} | |||||
public LoopVar(Tensor counter, TItem item) | |||||
{ | |||||
Counter = counter; | |||||
Item = item; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,32 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
internal class BodyItemInRnnWhileLoop | |||||
{ | |||||
/// <summary> | |||||
/// int32 scalar Tensor. | |||||
/// </summary> | |||||
public Tensor time { get; set; } | |||||
/// <summary> | |||||
/// List of `TensorArray`s that represent the output. | |||||
/// </summary> | |||||
public TensorArray[] output_ta_t { get; set; } | |||||
/// <summary> | |||||
/// nested tuple of vector tensors that represent the state. | |||||
/// </summary> | |||||
public Tensor state { get; set; } | |||||
public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) | |||||
{ | |||||
this.time = time; | |||||
this.output_ta_t = output_ta_t; | |||||
this.state = state; | |||||
} | |||||
public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) | |||||
=> (item.time, item.output_ta_t, item.state); | |||||
} | |||||
} |
@@ -145,7 +145,7 @@ namespace Tensorflow.Operations | |||||
{ | { | ||||
var ta = new TensorArray(dtype: dtype_, | var ta = new TensorArray(dtype: dtype_, | ||||
size: time_steps, | size: time_steps, | ||||
element_shape: new[] { element_shape }, | |||||
element_shape: element_shape, | |||||
tensor_array_name: base_name + name); | tensor_array_name: base_name + name); | ||||
return ta; | return ta; | ||||
}; | }; | ||||
@@ -178,19 +178,29 @@ namespace Tensorflow.Operations | |||||
// Make sure that we run at least 1 step, if necessary, to ensure | // Make sure that we run at least 1 step, if necessary, to ensure | ||||
// the TensorArrays pick up the dynamic shape. | // the TensorArrays pick up the dynamic shape. | ||||
Tensor loop_bound; | |||||
Tensor loop_bound = null; | |||||
if (in_graph_mode) | if (in_graph_mode) | ||||
loop_bound = math_ops.minimum( | loop_bound = math_ops.minimum( | ||||
time_steps, math_ops.maximum(1, max_sequence_length)); | time_steps, math_ops.maximum(1, max_sequence_length)); | ||||
/*Func<Tensor, Tensor> cond = (ctime) => | |||||
Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) => | |||||
{ | { | ||||
return null; | |||||
return time < loop_bound; | |||||
}; | }; | ||||
control_flow_ops.while_loop( | |||||
// Take a time step of the dynamic RNN. | |||||
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | |||||
{ | |||||
return item; | |||||
}; | |||||
control_flow_ops.while_loop<BodyItemInRnnWhileLoop>( | |||||
cond: cond, | cond: cond, | ||||
body = );*/ | |||||
body: _time_step, | |||||
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), | |||||
parallel_iterations: parallel_iterations, | |||||
maximum_iterations: time_steps, | |||||
swap_memory: swap_memory); | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -39,7 +39,7 @@ namespace Tensorflow.Operations | |||||
public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | ||||
string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | ||||
bool infer_shape = true, TensorShape[] element_shape = null, | |||||
bool infer_shape = true, TensorShape element_shape = null, | |||||
bool colocate_with_first_write_call = true, string name = null) | bool colocate_with_first_write_call = true, string name = null) | ||||
{ | { | ||||
_implementation = new _GraphTensorArray(dtype, | _implementation = new _GraphTensorArray(dtype, | ||||
@@ -44,7 +44,7 @@ namespace Tensorflow.Operations | |||||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | ||||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | ||||
bool infer_shape = true, TensorShape[] element_shape = null, | |||||
bool infer_shape = true, TensorShape element_shape = null, | |||||
bool colocate_with_first_write_call = true, string name = null) | bool colocate_with_first_write_call = true, string name = null) | ||||
{ | { | ||||
clear_after_read = clear_after_read ?? true; | clear_after_read = clear_after_read ?? true; | ||||
@@ -68,7 +68,7 @@ namespace Tensorflow.Operations | |||||
else | else | ||||
{ | { | ||||
_infer_shape = true; | _infer_shape = true; | ||||
_element_shape = new List<TensorShape> { }; | |||||
_element_shape = new List<TensorShape> { element_shape }; | |||||
} | } | ||||
tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | ||||
@@ -135,7 +135,7 @@ namespace Tensorflow.Operations | |||||
var ta = new TensorArray(_dtype, | var ta = new TensorArray(_dtype, | ||||
infer_shape:_infer_shape, | infer_shape:_infer_shape, | ||||
element_shape: _element_shape.ToArray(), | |||||
element_shape: _element_shape[0], | |||||
dynamic_size: _dynamic_size, | dynamic_size: _dynamic_size, | ||||
handle: _handle, | handle: _handle, | ||||
flow: flow_out, | flow: flow_out, | ||||
@@ -485,7 +485,7 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||||
public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||||
{ | { | ||||
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); | // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); | ||||
return tensors_or_flows; | return tensors_or_flows; | ||||
@@ -591,18 +591,18 @@ namespace Tensorflow | |||||
/// <param name="body"></param> | /// <param name="body"></param> | ||||
/// <param name="loop_vars"></param> | /// <param name="loop_vars"></param> | ||||
/// <param name="i"></param> | /// <param name="i"></param> | ||||
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars, | |||||
TensorShape shape_invariants = null, | TensorShape shape_invariants = null, | ||||
int parallel_iterations = 10, | int parallel_iterations = 10, | ||||
bool back_prop = true, | bool back_prop = true, | ||||
bool swap_memory = false, | bool swap_memory = false, | ||||
string name = null, | string name = null, | ||||
int? maximum_iterations = null, | |||||
Tensor maximum_iterations = null, | |||||
bool return_same_structure = false) | bool return_same_structure = false) | ||||
{ | { | ||||
tf_with(ops.name_scope(name, "while", loop_vars), scope => | tf_with(ops.name_scope(name, "while", loop_vars), scope => | ||||
{ | { | ||||
if (loop_vars == null || loop_vars.Length == 0) | |||||
if (loop_vars == null) | |||||
throw new ValueError("No loop variables provided"); | throw new ValueError("No loop variables provided"); | ||||
if (cond == null) | if (cond == null) | ||||
throw new ValueError("cond must be callable."); | throw new ValueError("cond must be callable."); | ||||
@@ -611,6 +611,28 @@ namespace Tensorflow | |||||
if (parallel_iterations < 1) | if (parallel_iterations < 1) | ||||
throw new ValueError("parallel_iterations must be a positive integer."); | throw new ValueError("parallel_iterations must be a positive integer."); | ||||
var try_to_pack = loop_vars is Tensor && !return_same_structure; | |||||
var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); | |||||
var orig_cond = cond; | |||||
var orig_body = body; | |||||
LoopVar<TItem> loop_vars_1 = null; | |||||
Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null; | |||||
Func<Tensor, TItem, Tensor> cond_buildloop = null; | |||||
if (try_to_pack) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars); | |||||
cond_buildloop = (i, lv) => | |||||
math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); | |||||
body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv)); | |||||
} | |||||
try_to_pack = false; | |||||
var loop_context = new WhileContext( | var loop_context = new WhileContext( | ||||
maximum_iterations: maximum_iterations, | maximum_iterations: maximum_iterations, | ||||
parallel_iterations: parallel_iterations, | parallel_iterations: parallel_iterations, | ||||
@@ -620,7 +642,7 @@ namespace Tensorflow | |||||
if (loop_context.outer_context == null) | if (loop_context.outer_context == null) | ||||
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | ||||
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||||
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, | |||||
return_same_structure); | return_same_structure); | ||||
if (maximum_iterations != null) | if (maximum_iterations != null) | ||||
@@ -28,12 +28,9 @@ namespace Tensorflow | |||||
} | } | ||||
public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | ||||
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||||
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | |||||
TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||||
bool identical_element_shapes = false, string tensor_array_name = "", string name = null) | |||||
{ | { | ||||
if (tensor_array_name == null) | |||||
tensor_array_name = string.Empty; | |||||
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | ||||
{ | { | ||||
size, | size, | ||||
@@ -223,7 +223,6 @@ namespace Tensorflow.Util | |||||
private static void _flatten_recursive<T>(T obj, List<T> list) | private static void _flatten_recursive<T>(T obj, List<T> list) | ||||
{ | { | ||||
switch(obj) | switch(obj) | ||||
{ | { | ||||
case IDictionary dict: | case IDictionary dict: | ||||
@@ -93,14 +93,14 @@ namespace Tensorflow | |||||
return new Session().as_default(); | return new Session().as_default(); | ||||
} | } | ||||
public Session Session(Graph graph, SessionOptions opts = null) | |||||
public Session Session(Graph graph, ConfigProto config = null) | |||||
{ | { | ||||
return new Session(graph, opts: opts).as_default(); | |||||
return new Session(graph, config: config).as_default(); | |||||
} | } | ||||
public Session Session(SessionOptions opts) | |||||
public Session Session(ConfigProto config) | |||||
{ | { | ||||
return new Session(null, opts).as_default(); | |||||
return new Session(null, config).as_default(); | |||||
} | } | ||||
public void __init__() | public void __init__() | ||||
@@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
{ | { | ||||
var opts = new SessionOptions(); | |||||
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||||
session_ = new Session(graph, opts, s); | |||||
var config = new ConfigProto {InterOpParallelismThreads = 4}; | |||||
session_ = new Session(graph, config, s); | |||||
} | } | ||||
} | } | ||||
@@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
var i = constant_op.constant(0, name: "i"); | var i = constant_op.constant(0, name: "i"); | ||||
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | ||||
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | ||||
var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||||
var r = control_flow_ops.while_loop(c, b, i); | |||||
} | } | ||||
private void _testWhileContextHelper(int? maximum_iterations = null) | |||||
private void _testWhileContextHelper(int maximum_iterations) | |||||
{ | { | ||||
// TODO: implement missing code dependencies | // TODO: implement missing code dependencies | ||||
using (var sess = this.cached_session()) | using (var sess = this.cached_session()) | ||||
@@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | ||||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | ||||
control_flow_ops.while_loop( | control_flow_ops.while_loop( | ||||
c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||||
c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||||
foreach (Operation op in sess.graph.get_operations()) | foreach (Operation op in sess.graph.get_operations()) | ||||
{ | { | ||||
var control_flow_context = op._get_control_flow_context(); | var control_flow_context = op._get_control_flow_context(); | ||||
@@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
} | } | ||||
} | } | ||||
[Ignore("TODO")] | |||||
[TestMethod] | |||||
public void testWhileContext() | |||||
{ | |||||
_testWhileContextHelper(); | |||||
} | |||||
[Ignore("TODO")] | [Ignore("TODO")] | ||||
[TestMethod] | [TestMethod] | ||||
public void testWhileContextWithMaximumIterations() | public void testWhileContextWithMaximumIterations() | ||||