Browse Source

Fix up the code in tf.scan

tags/v0.20
Brendan Mulcahy Haiping Chen 5 years ago
parent
commit
83a1ba6421
1 changed files with 70 additions and 73 deletions
  1. +70
    -73
      src/TensorFlowNET.Core/Operations/functional_ops.cs

+ 70
- 73
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -39,92 +39,89 @@ namespace Tensorflow
bool input_is_sequence = nest.is_sequence(elems); bool input_is_sequence = nest.is_sequence(elems);


List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x};
object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x;
object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0];


bool output_is_sequence; bool output_is_sequence;
Func<Tensor, List<Tensor>> output_flatten; Func<Tensor, List<Tensor>> output_flatten;
Func<List<Tensor>, object> output_pack;
if (initializer == null) if (initializer == null)
{ {
output_is_sequence = input_is_sequence; output_is_sequence = input_is_sequence;
output_flatten = input_flatten; output_flatten = input_flatten;
//output_pack = input_pack
output_pack = input_pack;
} }
else else
{ {
output_is_sequence = nest.is_sequence(initializer); output_is_sequence = nest.is_sequence(initializer);
output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x};
}

object output_pack(List<Tensor> x)
{
return output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];
output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];
} }


var elems_flat = input_flatten(elems); var elems_flat = input_flatten(elems);


bool in_graph_mode = true; // todo not context.executing_eagerly()
bool in_graph_mode = true; // todo !context.executing_eagerly()


//with ops.name_scope(name, "scan", elems_flat):
return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope => return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope =>
{ {
//if (in_graph_mode)
//{
// // Any get_variable calls in fn will cache the first call locally
// // and not issue repeated network I/O requests for each iteration.
// var varscope = variable_scope.get_variable_scope();
// bool varscope_caching_device_was_none = false;
// if (varscope.caching_device = null)
// {
// // varscope.set_caching_device(lambda op: op.device)
// // varscope_caching_device_was_none = True
// }
//}

elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();

// # Convert elems to tensor array. n may be known statically.
var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);
//if (n == null)
//{
// n = array_ops.shape(elems_flat[0])[0];
//}

// # TensorArrays are always flat
var elems_ta = elems_flat.Select(elem => new TensorArray(
elem.dtype,
size: tf.constant(n),
dynamic_size: false,
element_shape: elem.shape[0], //1:
infer_shape: true)).ToList();

for (int index = 0; index < elems_ta.Count; index++)
{
elems_ta[index].unstack(elems_flat[index]);
}
// todo tf.net doesn't expose .caching_device
//if (in_graph_mode)
//{
// // Any get_variable calls in fn will cache the first call locally
// // and not issue repeated network I/O requests for each iteration.
// var varscope = variable_scope.get_variable_scope();
// bool varscope_caching_device_was_none = false;
// if (varscope.caching_device = null)
// {
// // varscope.set_caching_device(lambda op: op.device)
// // varscope_caching_device_was_none = True
// }
//}

elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();

var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);

// todo python had the below but dimension_value returns int which can't be null
//if (n == null)
//{
// n = array_ops.shape(elems_flat[0])[0];
//}

var elems_ta = elems_flat.Select(elem => new TensorArray(
elem.dtype,
size: tf.constant(n),
dynamic_size: false,
element_shape: elem.shape.Skip(1).ToArray(),
infer_shape: true)).ToList();

for (int index = 0; index < elems_ta.Count; index++)
{
elems_ta[index].unstack(elems_flat[index]);
}


List<Tensor> a_flat;
int i;
if (initializer == null)
{
// a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
i = 1;
}
else
{
List<Tensor> initializer_flat = output_flatten(initializer as Tensor);
a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
i = 0;
}
List<Tensor> a_flat;
int i;
if (initializer == null)
{
a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
i = 1;
}
else
{
throw new NotImplementedException("Initializer not handled yet");
// todo the below in python, initializer is able to be passed as a List<Tensor>
//List<Tensor> initializer_flat = output_flatten(initializer);
//a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
//i = 0;
}


var accs_ta = a_flat.Select(init => new TensorArray(
dtype: init.dtype,
size: tf.constant(n),
element_shape: infer_shape ? init.shape : null,
dynamic_size: false,
infer_shape: infer_shape)).ToList();
var accs_ta = a_flat.Select(init => new TensorArray(
dtype: init.dtype,
size: tf.constant(n),
element_shape: infer_shape ? init.shape : null,
dynamic_size: false,
infer_shape: infer_shape)).ToList();


// if initializer is None:
if (initializer == null) if (initializer == null)
{ {
for (int index = 0; index < accs_ta.Count; index++) for (int index = 0; index < accs_ta.Count; index++)
@@ -135,14 +132,14 @@ namespace Tensorflow


(int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas) (int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas)
{ {
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.convert_to_tensor(_i))).ToList());
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
var packed_a = output_pack(a_flat_); var packed_a = output_pack(a_flat_);
var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?


var flat_a_out = output_flatten(a_out); var flat_a_out = output_flatten(a_out);
for (int j = 0; j < tas.Count; j++) for (int j = 0; j < tas.Count; j++)
{ {
tas[j].write(tf.convert_to_tensor(i), flat_a_out[j]); // todo brendan convert to tensor
tas[j].write(tf.constant(i), flat_a_out[j]);
} }


var next_i = reverse ? _i-- : _i++; var next_i = reverse ? _i-- : _i++;
@@ -154,13 +151,11 @@ namespace Tensorflow
if (reverse) if (reverse)
{ {
initial_i = n - 1 - i; initial_i = n - 1 - i;
// condition = lambda i, _1, _2: i >= 0
condition = x => tf.convert_to_tensor(x >= 0);
condition = x => tf.constant(x >= 0);
} }
else else
{ {
initial_i = i; initial_i = i;
// condition = lambda i, _1, _2: i < n
condition = x => tf.convert_to_tensor(x < n); condition = x => tf.convert_to_tensor(x < n);
} }


@@ -178,18 +173,20 @@ namespace Tensorflow


var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0])); var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));
foreach (var elem in elems_flat) // for elem in elems_flat[1:]:
foreach (var elem in elems_flat.Skip(1))
{ {
n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0]))); n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
} }


foreach (Tensor r in results_flat) foreach (Tensor r in results_flat)
{ {
r.set_shape(new TensorShape(n_static).concatenate(r.shape[0])); //r.shape[1:]
r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray()));
} }


// if in_graph_mode and varscope_caching_device_was_none:
// varscope.set_caching_device(None)
// todo get working when the above caching_device is fixed
//if (in_graph_mode && varscope_caching_device_was_none) {
// varscope.set_caching_device(None);
//}


return output_pack(results_flat); return output_pack(results_flat);
}); });


Loading…
Cancel
Save