Browse Source

Convert bodyItem to a class

tags/v0.20
Brendan Mulcahy Haiping Chen 5 years ago
parent
commit
3dcbf01ab2
1 changed files with 49 additions and 20 deletions
  1. +49
    -20
      src/TensorFlowNET.Core/Operations/functional_ops.cs

+ 49
- 20
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using NumSharp;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -128,60 +129,57 @@ namespace Tensorflow
} }
} }


(int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple)
BodyItem compute(BodyItem item)
{ {

(int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple;

var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
var packed_a = output_pack(a_flat_);
var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList());
var packed_a = output_pack(item.A_Flat);
var a_out = fn(packed_a, packed_elems);


var flat_a_out = output_flatten(a_out); var flat_a_out = output_flatten(a_out);
for (int j = 0; j < tas.Count; j++)
for (int j = 0; j < item.Accs_ta.Count; j++)
{ {
tas[j].write(tf.constant(i), flat_a_out[j]);
item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]);
} }


var next_i = reverse ? _i-- : _i++;
return (next_i, flat_a_out, tas);
var next_i = reverse ? item.I - 1 : item.I + 1;
return new BodyItem(next_i, flat_a_out, item.Accs_ta);
} }


int initial_i; int initial_i;
Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition;
Func<BodyItem, Tensor> condition;
if (reverse) if (reverse)
{ {
initial_i = n - 1 - i; initial_i = n - 1 - i;
condition = x => tf.constant(x.Item1 >= 0);
condition = x => tf.constant(x.I >= 0);
} }
else else
{ {
initial_i = i; initial_i = i;
condition = x => tf.constant(x.Item1 < n);
condition = x => tf.constant(x.I < n);
} }


(_, _, List<TensorArray> r_a) =
BodyItem bodyItem =
control_flow_ops.while_loop( control_flow_ops.while_loop(
condition, condition,
compute, compute,
(initial_i, a_flat, accs_ta),
new BodyItem(tf.constant(initial_i), a_flat, accs_ta),
parallel_iterations: parallel_iterations, parallel_iterations: parallel_iterations,
back_prop: back_prop, back_prop: back_prop,
swap_memory: swap_memory, swap_memory: swap_memory,
maximum_iterations: tf.constant(n)); maximum_iterations: tf.constant(n));


var results_flat = r_a.Select(r => r.stack()).ToList();
var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList();


var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));
var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));
foreach (var elem in elems_flat.Skip(1)) foreach (var elem in elems_flat.Skip(1))
{ {
n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0])));
} }


foreach (Tensor r in results_flat) foreach (Tensor r in results_flat)
{ {
r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray()));
r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")]));
} }


// todo get working when the above caching_device is fixed // todo get working when the above caching_device is fixed
@@ -192,6 +190,37 @@ namespace Tensorflow
return output_pack(results_flat); return output_pack(results_flat);
}); });
} }

internal class BodyItem : ICanBeFlattened, IPackable<BodyItem>
{
public Tensor I { get; set; }
public List<Tensor> A_Flat { get; set; }
public List<TensorArray> Accs_ta { get; set; }

public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta)
{
I = i;
A_Flat = a_flat;
Accs_ta = accs_ta;
}

public object[] Flatten()
{
var elements = new List<object> { I };
elements.AddRange(A_Flat);
elements.AddRange(Accs_ta);
return elements.ToArray();
}

public BodyItem Pack(object[] sequences)
{
I = sequences[0] as Tensor;
A_Flat = new List<Tensor> { sequences[1] as Tensor };
Accs_ta = new List<TensorArray> { sequences[2] as TensorArray };
return new BodyItem(I, A_Flat, Accs_ta);
}
}
} }
} }



Loading…
Cancel
Save