|
|
@@ -39,11 +39,11 @@ namespace Tensorflow |
|
|
|
bool input_is_sequence = nest.is_sequence(elems); |
|
|
|
|
|
|
|
List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; |
|
|
|
object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0]; |
|
|
|
Tensor input_pack(List<Tensor> x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0]; |
|
|
|
|
|
|
|
bool output_is_sequence; |
|
|
|
Func<Tensor, List<Tensor>> output_flatten; |
|
|
|
Func<List<Tensor>, object> output_pack; |
|
|
|
Func<List<Tensor>, Tensor> output_pack; |
|
|
|
if (initializer == null) |
|
|
|
{ |
|
|
|
output_is_sequence = input_is_sequence; |
|
|
@@ -54,7 +54,7 @@ namespace Tensorflow |
|
|
|
{ |
|
|
|
output_is_sequence = nest.is_sequence(initializer); |
|
|
|
output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; |
|
|
|
output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0]; |
|
|
|
output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0]; |
|
|
|
} |
|
|
|
|
|
|
|
var elems_flat = input_flatten(elems); |
|
|
@@ -130,8 +130,11 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
(int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas) |
|
|
|
(int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple) |
|
|
|
{ |
|
|
|
|
|
|
|
(int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple; |
|
|
|
|
|
|
|
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); |
|
|
|
var packed_a = output_pack(a_flat_); |
|
|
|
var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? |
|
|
@@ -147,19 +150,19 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
|
|
|
|
int initial_i; |
|
|
|
Func<int, Tensor> condition; |
|
|
|
Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition; |
|
|
|
if (reverse) |
|
|
|
{ |
|
|
|
initial_i = n - 1 - i; |
|
|
|
condition = x => tf.constant(x >= 0); |
|
|
|
condition = x => tf.constant(x.Item1 >= 0); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
initial_i = i; |
|
|
|
condition = x => tf.convert_to_tensor(x < n); |
|
|
|
condition = x => tf.constant(x.Item1 < n); |
|
|
|
} |
|
|
|
|
|
|
|
List<TensorArray> r_a = |
|
|
|
(_, _, List<TensorArray> r_a) = |
|
|
|
control_flow_ops.while_loop( |
|
|
|
condition, |
|
|
|
compute, |
|
|
@@ -167,7 +170,7 @@ namespace Tensorflow |
|
|
|
parallel_iterations: parallel_iterations, |
|
|
|
back_prop: back_prop, |
|
|
|
swap_memory: swap_memory, |
|
|
|
maximum_iterations: n); |
|
|
|
maximum_iterations: tf.constant(n)); |
|
|
|
|
|
|
|
var results_flat = r_a.Select(r => r.stack()).ToList(); |
|
|
|
|
|
|
|