diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs
index f0e1aa1c..674334bd 100644
--- a/src/TensorFlowNET.Core/Operations/functional_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs
@@ -17,6 +17,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using NumSharp;
 using Tensorflow.Framework;
 using Tensorflow.Util;
 using static Tensorflow.Binding;
@@ -128,60 +129,57 @@ namespace Tensorflow
                    }
                }

-               (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple)
+               BodyItem compute(BodyItem item)
                {
-
-                   (int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple;
-
-                   var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
-                   var packed_a = output_pack(a_flat_);
-                   var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?
+                   var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList());
+                   var packed_a = output_pack(item.A_Flat);
+                   var a_out = fn(packed_a, packed_elems);
                    var flat_a_out = output_flatten(a_out);
-                   for (int j = 0; j < tas.Count; j++)
+                   for (int j = 0; j < item.Accs_ta.Count; j++)
                    {
-                       tas[j].write(tf.constant(i), flat_a_out[j]);
+                       item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]);
                    }
-                   var next_i = reverse ? _i-- : _i++;
-                   return (next_i, flat_a_out, tas);
+                   var next_i = reverse ? item.I - 1 : item.I + 1;
+                   return new BodyItem(next_i, flat_a_out, item.Accs_ta);
                }

                int initial_i;
-               Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition;
+               Func<BodyItem, Tensor> condition;
                if (reverse)
                {
                    initial_i = n - 1 - i;
-                   condition = x => tf.constant(x.Item1 >= 0);
+                   condition = x => tf.constant(x.I >= 0);
                }
                else
                {
                    initial_i = i;
-                   condition = x => tf.constant(x.Item1 < n);
+                   condition = x => tf.constant(x.I < n);
                }

-               (_, _, List<TensorArray> r_a) =
+               BodyItem bodyItem =
                    control_flow_ops.while_loop(
                        condition,
                        compute,
-                       (initial_i, a_flat, accs_ta),
+                       new BodyItem(tf.constant(initial_i), a_flat, accs_ta),
                        parallel_iterations: parallel_iterations,
                        back_prop: back_prop,
                        swap_memory: swap_memory,
                        maximum_iterations: tf.constant(n));

-               var results_flat = r_a.Select(r => r.stack()).ToList();
+               var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList();

-               var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));
+               var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));
                foreach (var elem in elems_flat.Skip(1))
                {
-                   n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
+                   n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0])));
                }
                foreach (Tensor r in results_flat)
                {
-                   r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray()));
+                   r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")]));
                }

                // todo get working when the above caching_device is fixed
@@ -192,6 +190,37 @@ namespace Tensorflow
                return output_pack(results_flat);
            });
        }
+
+       internal class BodyItem : ICanBeFlattened, IPackable<BodyItem>
+       {
+           public Tensor I { get; set; }
+           public List<Tensor> A_Flat { get; set; }
+           public List<TensorArray> Accs_ta { get; set; }
+
+           public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta)
+           {
+               I = i;
+               A_Flat = a_flat;
+               Accs_ta = accs_ta;
+           }
+
+           public object[] Flatten()
+           {
+               var elements = new List<object> { I };
+               elements.AddRange(A_Flat);
+               elements.AddRange(Accs_ta);
+               return elements.ToArray();
+           }
+
+           public BodyItem Pack(object[] sequences)
+           {
+               I = sequences[0] as Tensor;
+               A_Flat = new List<Tensor> { sequences[1] as Tensor };
+               Accs_ta = new List<TensorArray> { sequences[2] as TensorArray };
+
+               return new BodyItem(I, A_Flat, Accs_ta);
+           }
+       }
    }
}
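
Note for reviewers: the crux of this change is moving the loop state from a ValueTuple to BodyItem, which implements ICanBeFlattened and IPackable<BodyItem> so that control_flow_ops.while_loop can flatten the state into an object[] of graph values on the way into each iteration and pack it back into a structured item before invoking `condition` and `compute`. Below is a minimal, self-contained sketch of that flatten/pack round-trip contract, with plain ints standing in for Tensor/TensorArray; LoopState and Demo are illustrative names invented for this sketch, not part of this patch.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    // Stand-in for BodyItem: the loop machinery flattens the structured
    // state into a plain object[] so it can thread each element through
    // the graph, then packs the array back into a structured item.
    class LoopState
    {
        public int I;
        public List<int> A_Flat;

        public LoopState(int i, List<int> aFlat) { I = i; A_Flat = aFlat; }

        // Mirror of ICanBeFlattened.Flatten: emits [I, ...A_Flat].
        public object[] Flatten()
            => new object[] { I }.Concat(A_Flat.Cast<object>()).ToArray();

        // Mirror of IPackable<T>.Pack: rebuilds the structure from the
        // flat array produced by one loop iteration.
        public LoopState Pack(object[] sequences)
            => new LoopState((int)sequences[0], sequences.Skip(1).Cast<int>().ToList());
    }

    class Demo
    {
        static void Main()
        {
            var state = new LoopState(0, new List<int> { 1, 2, 3 });
            var roundTripped = state.Pack(state.Flatten());
            // Prints "0 [1, 2, 3]": the state survives a flatten/pack cycle.
            Console.WriteLine($"{roundTripped.I} [{string.Join(", ", roundTripped.A_Flat)}]");
        }
    }

One design consequence worth flagging: BodyItem.Flatten() appends every element of A_Flat and Accs_ta, while BodyItem.Pack() reads back only sequences[1] and sequences[2], so the round-trip holds only when each list contains a single element (the single-accumulator case this scan currently exercises).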