
Convert bodyItem to a class

tags/v0.20
Brendan Mulcahy, Haiping Chen · 5 years ago
parent commit 3dcbf01ab2
1 changed file with 49 additions and 20 deletions
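In short: the loop state that functional_ops.cs threads through control_flow_ops.while_loop changes from an anonymous ValueTuple to a BodyItem class implementing ICanBeFlattened and IPackable<BodyItem>, so the local body function goes from

    (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple)

to

    BodyItem compute(BodyItem item)

and the initial loop value becomes new BodyItem(tf.constant(initial_i), a_flat, accs_ta).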
src/TensorFlowNET.Core/Operations/functional_ops.cs

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+ using NumSharp;
using Tensorflow.Framework;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -128,60 +129,57 @@ namespace Tensorflow
}
}

- (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple)
+ BodyItem compute(BodyItem item)
{

- (int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple;
-
- var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
- var packed_a = output_pack(a_flat_);
- var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?
+ var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList());
+ var packed_a = output_pack(item.A_Flat);
+ var a_out = fn(packed_a, packed_elems);

var flat_a_out = output_flatten(a_out);
- for (int j = 0; j < tas.Count; j++)
+ for (int j = 0; j < item.Accs_ta.Count; j++)
{
- tas[j].write(tf.constant(i), flat_a_out[j]);
+ item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]);
}

- var next_i = reverse ? _i-- : _i++;
- return (next_i, flat_a_out, tas);
+ var next_i = reverse ? item.I - 1 : item.I + 1;
+ return new BodyItem(next_i, flat_a_out, item.Accs_ta);
}

int initial_i;
- Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition;
+ Func<BodyItem, Tensor> condition;
if (reverse)
{
initial_i = n - 1 - i;
- condition = x => tf.constant(x.Item1 >= 0);
+ condition = x => tf.constant(x.I >= 0);
}
else
{
initial_i = i;
- condition = x => tf.constant(x.Item1 < n);
+ condition = x => tf.constant(x.I < n);
}

- (_, _, List<TensorArray> r_a) =
+ BodyItem bodyItem =
control_flow_ops.while_loop(
condition,
compute,
- (initial_i, a_flat, accs_ta),
+ new BodyItem(tf.constant(initial_i), a_flat, accs_ta),
parallel_iterations: parallel_iterations,
back_prop: back_prop,
swap_memory: swap_memory,
maximum_iterations: tf.constant(n));

- var results_flat = r_a.Select(r => r.stack()).ToList();
+ var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList();

- var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));
+ var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));
foreach (var elem in elems_flat.Skip(1))
{
- n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
+ n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0])));
}

foreach (Tensor r in results_flat)
{
- r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray()));
+ r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")]));
}

// todo get working when the above caching_device is fixed
@@ -192,6 +190,37 @@ namespace Tensorflow
return output_pack(results_flat);
});
}

+ internal class BodyItem : ICanBeFlattened, IPackable<BodyItem>
+ {
+ public Tensor I { get; set; }
+ public List<Tensor> A_Flat { get; set; }
+ public List<TensorArray> Accs_ta { get; set; }
+
+ public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta)
+ {
+ I = i;
+ A_Flat = a_flat;
+ Accs_ta = accs_ta;
+ }
+
+ public object[] Flatten()
+ {
+ var elements = new List<object> { I };
+ elements.AddRange(A_Flat);
+ elements.AddRange(Accs_ta);
+ return elements.ToArray();
+ }
+
+ public BodyItem Pack(object[] sequences)
+ {
+ I = sequences[0] as Tensor;
+ A_Flat = new List<Tensor> { sequences[1] as Tensor };
+ Accs_ta = new List<TensorArray> { sequences[2] as TensorArray };
+ return new BodyItem(I, A_Flat, Accs_ta);
+ }
+ }
}
}
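For context on why the new class implements ICanBeFlattened and IPackable<BodyItem>: a graph-level while loop can only thread loose values between iterations, so a structured loop state presumably has to be flattened on the way in and packed back on the way out. The sketch below is a minimal, self-contained illustration of that round trip. It does not use the real TensorFlow.NET types; the two interfaces only mirror the member names that appear in the diff, and LoopState, RunLoop and the element types are hypothetical stand-ins.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    // Hypothetical stand-ins for the interfaces referenced in the diff; the real ones
    // in TensorFlow.NET operate on Tensor / TensorArray values.
    interface ICanBeFlattened { object[] Flatten(); }
    interface IPackable<T> { T Pack(object[] sequences); }

    // Simplified analogue of BodyItem: a loop counter plus an accumulator list.
    class LoopState : ICanBeFlattened, IPackable<LoopState>
    {
        public int I { get; set; }
        public List<double> Acc { get; set; }

        public LoopState(int i, List<double> acc) { I = i; Acc = acc; }

        // Lay the structured state out as a flat, positional sequence,
        // which is the only shape the loop driver passes between iterations.
        public object[] Flatten()
        {
            var elements = new List<object> { I };
            elements.AddRange(Acc.Cast<object>());
            return elements.ToArray();
        }

        // Rebuild the structured state from that flat sequence.
        public LoopState Pack(object[] sequences)
            => new LoopState((int)sequences[0], sequences.Skip(1).Cast<double>().ToList());
    }

    class Demo
    {
        // Toy loop driver: like a graph-level while loop, it only ever stores the flat form.
        static LoopState RunLoop(Func<LoopState, bool> cond, Func<LoopState, LoopState> body, LoopState init)
        {
            object[] flat = init.Flatten();
            while (cond(init.Pack(flat)))
                flat = body(init.Pack(flat)).Flatten();
            return init.Pack(flat);
        }

        static void Main()
        {
            var result = RunLoop(
                s => s.I < 3,
                s => new LoopState(s.I + 1, s.Acc.Append((double)s.I).ToList()),
                new LoopState(0, new List<double>()));
            Console.WriteLine($"I={result.I}, Acc=[{string.Join(", ", result.Acc)}]");  // prints I=3, Acc=[0, 1, 2]
        }
    }

Note that the committed Pack only reads sequences[1] and sequences[2] back as a single Tensor and a single TensorArray, so it appears to assume one accumulator of each kind; the sketch rebuilds whole lists instead, purely for illustration.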

