|
|
@@ -17,6 +17,7 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using NumSharp; |
|
|
|
using Tensorflow.Framework; |
|
|
|
using Tensorflow.Util; |
|
|
|
using static Tensorflow.Binding; |
|
|
@@ -128,60 +129,57 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
(int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple) |
|
|
|
BodyItem compute(BodyItem item) |
|
|
|
{ |
|
|
|
|
|
|
|
(int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple; |
|
|
|
|
|
|
|
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); |
|
|
|
var packed_a = output_pack(a_flat_); |
|
|
|
var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? |
|
|
|
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList()); |
|
|
|
var packed_a = output_pack(item.A_Flat); |
|
|
|
var a_out = fn(packed_a, packed_elems); |
|
|
|
|
|
|
|
var flat_a_out = output_flatten(a_out); |
|
|
|
for (int j = 0; j < tas.Count; j++) |
|
|
|
for (int j = 0; j < item.Accs_ta.Count; j++) |
|
|
|
{ |
|
|
|
tas[j].write(tf.constant(i), flat_a_out[j]); |
|
|
|
item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]); |
|
|
|
} |
|
|
|
|
|
|
|
var next_i = reverse ? _i-- : _i++; |
|
|
|
return (next_i, flat_a_out, tas); |
|
|
|
var next_i = reverse ? item.I - 1 : item.I + 1; |
|
|
|
return new BodyItem(next_i, flat_a_out, item.Accs_ta); |
|
|
|
} |
|
|
|
|
|
|
|
int initial_i; |
|
|
|
Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition; |
|
|
|
Func<BodyItem, Tensor> condition; |
|
|
|
if (reverse) |
|
|
|
{ |
|
|
|
initial_i = n - 1 - i; |
|
|
|
condition = x => tf.constant(x.Item1 >= 0); |
|
|
|
condition = x => tf.constant(x.I >= 0); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
initial_i = i; |
|
|
|
condition = x => tf.constant(x.Item1 < n); |
|
|
|
condition = x => tf.constant(x.I < n); |
|
|
|
} |
|
|
|
|
|
|
|
(_, _, List<TensorArray> r_a) = |
|
|
|
BodyItem bodyItem = |
|
|
|
control_flow_ops.while_loop( |
|
|
|
condition, |
|
|
|
compute, |
|
|
|
(initial_i, a_flat, accs_ta), |
|
|
|
new BodyItem(tf.constant(initial_i), a_flat, accs_ta), |
|
|
|
parallel_iterations: parallel_iterations, |
|
|
|
back_prop: back_prop, |
|
|
|
swap_memory: swap_memory, |
|
|
|
maximum_iterations: tf.constant(n)); |
|
|
|
|
|
|
|
var results_flat = r_a.Select(r => r.stack()).ToList(); |
|
|
|
var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList(); |
|
|
|
|
|
|
|
var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0])); |
|
|
|
var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0])); |
|
|
|
|
|
|
|
foreach (var elem in elems_flat.Skip(1)) |
|
|
|
{ |
|
|
|
n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0]))); |
|
|
|
n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0]))); |
|
|
|
} |
|
|
|
|
|
|
|
foreach (Tensor r in results_flat) |
|
|
|
{ |
|
|
|
r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray())); |
|
|
|
r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")])); |
|
|
|
} |
|
|
|
|
|
|
|
// todo get working when the above caching_device is fixed |
|
|
@@ -192,6 +190,37 @@ namespace Tensorflow |
|
|
|
return output_pack(results_flat); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
internal class BodyItem : ICanBeFlattened, IPackable<BodyItem> |
|
|
|
{ |
|
|
|
public Tensor I { get; set; } |
|
|
|
public List<Tensor> A_Flat { get; set; } |
|
|
|
public List<TensorArray> Accs_ta { get; set; } |
|
|
|
|
|
|
|
public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta) |
|
|
|
{ |
|
|
|
I = i; |
|
|
|
A_Flat = a_flat; |
|
|
|
Accs_ta = accs_ta; |
|
|
|
} |
|
|
|
|
|
|
|
public object[] Flatten() |
|
|
|
{ |
|
|
|
var elements = new List<object> { I }; |
|
|
|
elements.AddRange(A_Flat); |
|
|
|
elements.AddRange(Accs_ta); |
|
|
|
return elements.ToArray(); |
|
|
|
} |
|
|
|
|
|
|
|
public BodyItem Pack(object[] sequences) |
|
|
|
{ |
|
|
|
I = sequences[0] as Tensor; |
|
|
|
A_Flat = new List<Tensor> { sequences[1] as Tensor }; |
|
|
|
Accs_ta = new List<TensorArray> { sequences[2] as TensorArray }; |
|
|
|
|
|
|
|
return new BodyItem(I, A_Flat, Accs_ta); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|