@@ -37,7 +37,7 @@ namespace Tensorflow | |||
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation | |||
=> control_flow_ops.group(inputs, name: name); | |||
public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
TensorShape shape_invariants = null, | |||
int parallel_iterations = 10, | |||
bool back_prop = true, | |||
@@ -52,7 +52,7 @@ namespace Tensorflow | |||
swap_memory: swap_memory, | |||
name: name, | |||
maximum_iterations: maximum_iterations, | |||
return_same_structure: return_same_structure); | |||
return_same_structure: return_same_structure);*/ | |||
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
=> ops.control_dependencies(control_inputs); | |||
@@ -0,0 +1,25 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
internal class LoopVar<TItem> | |||
{ | |||
public Tensor Counter { get; } | |||
public TItem[] Items { get; } | |||
public TItem Item { get; } | |||
public LoopVar(Tensor counter, TItem[] items) | |||
{ | |||
Counter = counter; | |||
Items = items; | |||
} | |||
public LoopVar(Tensor counter, TItem item) | |||
{ | |||
Counter = counter; | |||
Item = item; | |||
} | |||
} | |||
} |
@@ -0,0 +1,32 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
internal class BodyItemInRnnWhileLoop | |||
{ | |||
/// <summary> | |||
/// int32 scalar Tensor. | |||
/// </summary> | |||
public Tensor time { get; set; } | |||
/// <summary> | |||
/// List of `TensorArray`s that represent the output. | |||
/// </summary> | |||
public TensorArray[] output_ta_t { get; set; } | |||
/// <summary> | |||
/// nested tuple of vector tensors that represent the state. | |||
/// </summary> | |||
public Tensor state { get; set; } | |||
public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) | |||
{ | |||
this.time = time; | |||
this.output_ta_t = output_ta_t; | |||
this.state = state; | |||
} | |||
public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) | |||
=> (item.time, item.output_ta_t, item.state); | |||
} | |||
} |
@@ -145,7 +145,7 @@ namespace Tensorflow.Operations | |||
{ | |||
var ta = new TensorArray(dtype: dtype_, | |||
size: time_steps, | |||
element_shape: new[] { element_shape }, | |||
element_shape: element_shape, | |||
tensor_array_name: base_name + name); | |||
return ta; | |||
}; | |||
@@ -178,19 +178,29 @@ namespace Tensorflow.Operations | |||
// Make sure that we run at least 1 step, if necessary, to ensure | |||
// the TensorArrays pick up the dynamic shape. | |||
Tensor loop_bound; | |||
Tensor loop_bound = null; | |||
if (in_graph_mode) | |||
loop_bound = math_ops.minimum( | |||
time_steps, math_ops.maximum(1, max_sequence_length)); | |||
/*Func<Tensor, Tensor> cond = (ctime) => | |||
Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) => | |||
{ | |||
return null; | |||
return time < loop_bound; | |||
}; | |||
control_flow_ops.while_loop( | |||
// Take a time step of the dynamic RNN. | |||
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | |||
{ | |||
return item; | |||
}; | |||
control_flow_ops.while_loop<BodyItemInRnnWhileLoop>( | |||
cond: cond, | |||
body = );*/ | |||
body: _time_step, | |||
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), | |||
parallel_iterations: parallel_iterations, | |||
maximum_iterations: time_steps, | |||
swap_memory: swap_memory); | |||
throw new NotImplementedException(""); | |||
} | |||
@@ -39,7 +39,7 @@ namespace Tensorflow.Operations | |||
public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||
string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, TensorShape[] element_shape = null, | |||
bool infer_shape = true, TensorShape element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
_implementation = new _GraphTensorArray(dtype, | |||
@@ -44,7 +44,7 @@ namespace Tensorflow.Operations | |||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, TensorShape[] element_shape = null, | |||
bool infer_shape = true, TensorShape element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
clear_after_read = clear_after_read ?? true; | |||
@@ -68,7 +68,7 @@ namespace Tensorflow.Operations | |||
else | |||
{ | |||
_infer_shape = true; | |||
_element_shape = new List<TensorShape> { }; | |||
_element_shape = new List<TensorShape> { element_shape }; | |||
} | |||
tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | |||
@@ -135,7 +135,7 @@ namespace Tensorflow.Operations | |||
var ta = new TensorArray(_dtype, | |||
infer_shape:_infer_shape, | |||
element_shape: _element_shape.ToArray(), | |||
element_shape: _element_shape[0], | |||
dynamic_size: _dynamic_size, | |||
handle: _handle, | |||
flow: flow_out, | |||
@@ -485,7 +485,7 @@ namespace Tensorflow | |||
}); | |||
} | |||
public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||
public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||
{ | |||
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); | |||
return tensors_or_flows; | |||
@@ -591,18 +591,18 @@ namespace Tensorflow | |||
/// <param name="body"></param> | |||
/// <param name="loop_vars"></param> | |||
/// <param name="i"></param> | |||
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars, | |||
TensorShape shape_invariants = null, | |||
int parallel_iterations = 10, | |||
bool back_prop = true, | |||
bool swap_memory = false, | |||
string name = null, | |||
int? maximum_iterations = null, | |||
Tensor maximum_iterations = null, | |||
bool return_same_structure = false) | |||
{ | |||
tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||
{ | |||
if (loop_vars == null || loop_vars.Length == 0) | |||
if (loop_vars == null) | |||
throw new ValueError("No loop variables provided"); | |||
if (cond == null) | |||
throw new ValueError("cond must be callable."); | |||
@@ -611,6 +611,28 @@ namespace Tensorflow | |||
if (parallel_iterations < 1) | |||
throw new ValueError("parallel_iterations must be a positive integer."); | |||
var try_to_pack = loop_vars is Tensor && !return_same_structure; | |||
var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); | |||
var orig_cond = cond; | |||
var orig_body = body; | |||
LoopVar<TItem> loop_vars_1 = null; | |||
Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null; | |||
Func<Tensor, TItem, Tensor> cond_buildloop = null; | |||
if (try_to_pack) | |||
{ | |||
} | |||
else | |||
{ | |||
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars); | |||
cond_buildloop = (i, lv) => | |||
math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); | |||
body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv)); | |||
} | |||
try_to_pack = false; | |||
var loop_context = new WhileContext( | |||
maximum_iterations: maximum_iterations, | |||
parallel_iterations: parallel_iterations, | |||
@@ -620,7 +642,7 @@ namespace Tensorflow | |||
if (loop_context.outer_context == null) | |||
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, | |||
return_same_structure); | |||
if (maximum_iterations != null) | |||
@@ -28,12 +28,9 @@ namespace Tensorflow | |||
} | |||
public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | |||
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | |||
TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
bool identical_element_shapes = false, string tensor_array_name = "", string name = null) | |||
{ | |||
if (tensor_array_name == null) | |||
tensor_array_name = string.Empty; | |||
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | |||
{ | |||
size, | |||
@@ -223,7 +223,6 @@ namespace Tensorflow.Util | |||
private static void _flatten_recursive<T>(T obj, List<T> list) | |||
{ | |||
switch(obj) | |||
{ | |||
case IDictionary dict: | |||
@@ -93,14 +93,14 @@ namespace Tensorflow | |||
return new Session().as_default(); | |||
} | |||
public Session Session(Graph graph, SessionOptions opts = null) | |||
public Session Session(Graph graph, ConfigProto config = null) | |||
{ | |||
return new Session(graph, opts: opts).as_default(); | |||
return new Session(graph, config: config).as_default(); | |||
} | |||
public Session Session(SessionOptions opts) | |||
public Session Session(ConfigProto config) | |||
{ | |||
return new Session(null, opts).as_default(); | |||
return new Session(null, config).as_default(); | |||
} | |||
public void __init__() | |||
@@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
lock (Locks.ProcessWide) | |||
{ | |||
var opts = new SessionOptions(); | |||
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||
session_ = new Session(graph, opts, s); | |||
var config = new ConfigProto {InterOpParallelismThreads = 4}; | |||
session_ = new Session(graph, config, s); | |||
} | |||
} | |||
@@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
var i = constant_op.constant(0, name: "i"); | |||
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | |||
var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||
var r = control_flow_ops.while_loop(c, b, i); | |||
} | |||
private void _testWhileContextHelper(int? maximum_iterations = null) | |||
private void _testWhileContextHelper(int maximum_iterations) | |||
{ | |||
// TODO: implement missing code dependencies | |||
using (var sess = this.cached_session()) | |||
@@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
control_flow_ops.while_loop( | |||
c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||
c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
foreach (Operation op in sess.graph.get_operations()) | |||
{ | |||
var control_flow_context = op._get_control_flow_context(); | |||
@@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
} | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testWhileContext() | |||
{ | |||
_testWhileContextHelper(); | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testWhileContextWithMaximumIterations() | |||