Browse Source

Adjust types to get while_loop working

tags/v0.20
Brendan Mulcahy Haiping Chen 5 years ago
parent
commit
593ce2b6c3
1 changed files with 12 additions and 9 deletions
  1. +12
    -9
      src/TensorFlowNET.Core/Operations/functional_ops.cs

+ 12
- 9
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -39,11 +39,11 @@ namespace Tensorflow
bool input_is_sequence = nest.is_sequence(elems); bool input_is_sequence = nest.is_sequence(elems);


List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x};
object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0];
Tensor input_pack(List<Tensor> x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0];


bool output_is_sequence; bool output_is_sequence;
Func<Tensor, List<Tensor>> output_flatten; Func<Tensor, List<Tensor>> output_flatten;
Func<List<Tensor>, object> output_pack;
Func<List<Tensor>, Tensor> output_pack;
if (initializer == null) if (initializer == null)
{ {
output_is_sequence = input_is_sequence; output_is_sequence = input_is_sequence;
@@ -54,7 +54,7 @@ namespace Tensorflow
{ {
output_is_sequence = nest.is_sequence(initializer); output_is_sequence = nest.is_sequence(initializer);
output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x};
output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];
output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0];
} }


var elems_flat = input_flatten(elems); var elems_flat = input_flatten(elems);
@@ -130,8 +130,11 @@ namespace Tensorflow
} }
} }


(int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas)
(int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple)
{ {

(int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple;

var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
var packed_a = output_pack(a_flat_); var packed_a = output_pack(a_flat_);
var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?
@@ -147,19 +150,19 @@ namespace Tensorflow
} }


int initial_i; int initial_i;
Func<int, Tensor> condition;
Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition;
if (reverse) if (reverse)
{ {
initial_i = n - 1 - i; initial_i = n - 1 - i;
condition = x => tf.constant(x >= 0);
condition = x => tf.constant(x.Item1 >= 0);
} }
else else
{ {
initial_i = i; initial_i = i;
condition = x => tf.convert_to_tensor(x < n);
condition = x => tf.constant(x.Item1 < n);
} }


List<TensorArray> r_a =
(_, _, List<TensorArray> r_a) =
control_flow_ops.while_loop( control_flow_ops.while_loop(
condition, condition,
compute, compute,
@@ -167,7 +170,7 @@ namespace Tensorflow
parallel_iterations: parallel_iterations, parallel_iterations: parallel_iterations,
back_prop: back_prop, back_prop: back_prop,
swap_memory: swap_memory, swap_memory: swap_memory,
maximum_iterations: n);
maximum_iterations: tf.constant(n));


var results_flat = r_a.Select(r => r.stack()).ToList(); var results_flat = r_a.Select(r => r.stack()).ToList();




Loading…
Cancel
Save