From 83a1ba642175aa81a82cec28ab0b0423dbad6aa7 Mon Sep 17 00:00:00 2001
From: Brendan Mulcahy
Date: Fri, 29 Nov 2019 16:21:47 -0500
Subject: [PATCH] Fix up the code in tf.scan

---
 .../Operations/functional_ops.cs | 143 +++++++++---------
 1 file changed, 70 insertions(+), 73 deletions(-)

diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs
index 85292851..f392d766 100644
--- a/src/TensorFlowNET.Core/Operations/functional_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs
@@ -39,92 +39,89 @@ namespace Tensorflow
             bool input_is_sequence = nest.is_sequence(elems);
             List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> { x };
-            object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x;
+            object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0];
             bool output_is_sequence;
             Func<Tensor, List<Tensor>> output_flatten;
+            Func<List<Tensor>, object> output_pack;
 
             if (initializer == null)
             {
                 output_is_sequence = input_is_sequence;
                 output_flatten = input_flatten;
-                //output_pack = input_pack
+                output_pack = input_pack;
             }
             else
             {
                 output_is_sequence = nest.is_sequence(initializer);
                 output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> { x };
-            }
-
-            object output_pack(List<Tensor> x)
-            {
-                return output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];
+                output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];
             }
 
             var elems_flat = input_flatten(elems);
-            bool in_graph_mode = true; // todo not context.executing_eagerly()
+            bool in_graph_mode = true; // todo !context.executing_eagerly()
 
-            //with ops.name_scope(name, "scan", elems_flat):
             return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope =>
             {
-                //if (in_graph_mode)
-                //{
-                //    // Any get_variable calls in fn will cache the first call locally
-                //    // and not issue repeated network I/O requests for each iteration.
-                //    var varscope = variable_scope.get_variable_scope();
-                //    bool varscope_caching_device_was_none = false;
-                //    if (varscope.caching_device = null)
-                //    {
-                //        // varscope.set_caching_device(lambda op: op.device)
-                //        // varscope_caching_device_was_none = True
-                //    }
-                //}
-
-                elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();
-
-                // # Convert elems to tensor array. n may be known statically.
-                var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);
-                //if (n == null)
-                //{
-                //    n = array_ops.shape(elems_flat[0])[0];
-                //}
-
-                // # TensorArrays are always flat
-                var elems_ta = elems_flat.Select(elem => new TensorArray(
-                    elem.dtype,
-                    size: tf.constant(n),
-                    dynamic_size: false,
-                    element_shape: elem.shape[0], //1:
-                    infer_shape: true)).ToList();
-
-                for (int index = 0; index < elems_ta.Count; index++)
-                {
-                    elems_ta[index].unstack(elems_flat[index]);
-                }
+                // todo tf.net doesn't expose .caching_device
+                //if (in_graph_mode)
+                //{
+                //    // Any get_variable calls in fn will cache the first call locally
+                //    // and not issue repeated network I/O requests for each iteration.
+                //    var varscope = variable_scope.get_variable_scope();
+                //    bool varscope_caching_device_was_none = false;
+                //    if (varscope.caching_device == null)
+                //    {
+                //        // varscope.set_caching_device(lambda op: op.device)
+                //        // varscope_caching_device_was_none = True
+                //    }
+                //}
+
+                elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();
+
+                var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);
+
+                // todo python had the below but dimension_value returns int which can't be null
+                //if (n == null)
+                //{
+                //    n = array_ops.shape(elems_flat[0])[0];
+                //}
+
+                var elems_ta = elems_flat.Select(elem => new TensorArray(
+                    elem.dtype,
+                    size: tf.constant(n),
+                    dynamic_size: false,
+                    element_shape: elem.shape.Skip(1).ToArray(),
+                    infer_shape: true)).ToList();
+
+                for (int index = 0; index < elems_ta.Count; index++)
+                {
+                    elems_ta[index].unstack(elems_flat[index]);
+                }
 
-                List<Tensor> a_flat;
-                int i;
-                if (initializer == null)
-                {
-                    // a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
-                    a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
-                    i = 1;
-                }
-                else
-                {
-                    List<Tensor> initializer_flat = output_flatten(initializer as Tensor);
-                    a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
-                    i = 0;
-                }
+                List<Tensor> a_flat;
+                int i;
+                if (initializer == null)
+                {
+                    a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
+                    i = 1;
+                }
+                else
+                {
+                    throw new NotImplementedException("Initializer not handled yet");
+                    // todo the below in python, initializer is able to be passed as a List
+                    //List<Tensor> initializer_flat = output_flatten(initializer);
+                    //a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
+                    //i = 0;
+                }
 
-                var accs_ta = a_flat.Select(init => new TensorArray(
-                    dtype: init.dtype,
-                    size: tf.constant(n),
-                    element_shape: infer_shape ? init.shape : null,
-                    dynamic_size: false,
-                    infer_shape: infer_shape)).ToList();
+                var accs_ta = a_flat.Select(init => new TensorArray(
+                    dtype: init.dtype,
+                    size: tf.constant(n),
+                    element_shape: infer_shape ? init.shape : null,
+                    dynamic_size: false,
+                    infer_shape: infer_shape)).ToList();
 
-                // if initializer is None:
                 if (initializer == null)
                 {
                     for (int index = 0; index < accs_ta.Count; index++)
@@ -135,14 +132,14 @@ namespace Tensorflow
 
                 (int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas)
                 {
-                    var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.convert_to_tensor(_i))).ToList());
+                    var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
                     var packed_a = output_pack(a_flat_);
                     var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?
                     var flat_a_out = output_flatten(a_out);
                     for (int j = 0; j < tas.Count; j++)
                     {
-                        tas[j].write(tf.convert_to_tensor(i), flat_a_out[j]); // todo brendan convert to tensor
+                        tas[j].write(tf.constant(i), flat_a_out[j]);
                     }
                     var next_i = reverse ? _i - 1 : _i + 1;
                    return (next_i, flat_a_out, tas);
                }
@@ -154,13 +151,11 @@ namespace Tensorflow
 
                 if (reverse)
                 {
                     initial_i = n - 1 - i;
-                    // condition = lambda i, _1, _2: i >= 0
-                    condition = x => tf.convert_to_tensor(x >= 0);
+                    condition = x => tf.constant(x >= 0);
                 }
                 else
                 {
                     initial_i = i;
-                    // condition = lambda i, _1, _2: i < n
                     condition = x => tf.convert_to_tensor(x < n);
                 }
@@ -178,18 +173,20 @@ namespace Tensorflow
 
                 var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));
 
-                foreach (var elem in elems_flat) // for elem in elems_flat[1:]:
+                foreach (var elem in elems_flat.Skip(1))
                 {
                     n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
                 }
 
                 foreach (Tensor r in results_flat)
                 {
-                    r.set_shape(new TensorShape(n_static).concatenate(r.shape[0])); //r.shape[1:]
+                    r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray()));
                 }
 
-                // if in_graph_mode and varscope_caching_device_was_none:
-                //     varscope.set_caching_device(None)
+                // todo get working when the above caching_device is fixed
+                //if (in_graph_mode && varscope_caching_device_was_none) {
+                //    varscope.set_caching_device(None);
+                //}
 
                 return output_pack(results_flat);
             });
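
Note (editor's addition, not part of the commit): a minimal sketch of how the
reworked scan might be driven as a cumulative sum, to make the loop logic above
concrete. The harness is an assumption: the class name ScanSmokeTest, the
tf.constant/tf.Session overloads, and invoking the method as
functional_ops.scan(fn, elems, ..., reverse:) are inferred from the signature
visible in the patch and the TF.NET API of this era, not verified against this
exact revision.

using System;
using Tensorflow;
using static Tensorflow.Binding;

// Hypothetical smoke test; assumes graph mode, matching the hard-coded
// `in_graph_mode = true` in the diff.
class ScanSmokeTest
{
    static void Main()
    {
        var elems = tf.constant(new[] { 1, 2, 3, 4, 5, 6 });

        // With no initializer, the first element seeds the accumulator and the
        // loop starts at i = 1, so fn first sees (a: elems[0], x: elems[1]).
        var cumsum = functional_ops.scan((a, x) => a + x, elems);

        // reverse: true starts at n - 1 and steps downward, per the
        // `initial_i = n - 1 - i` / `x >= 0` condition in the patch.
        var rcumsum = functional_ops.scan((a, x) => a + x, elems, reverse: true);

        using (var sess = tf.Session())
        {
            Console.WriteLine(sess.run(cumsum));  // expected [1, 3, 6, 10, 15, 21]
            Console.WriteLine(sess.run(rcumsum)); // expected [21, 20, 18, 15, 11, 6]
        }
    }
}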