|
|
@@ -39,92 +39,89 @@ namespace Tensorflow |
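// Flatten/pack helpers mirroring nest.flatten / nest.pack_sequence_as in the Python
// implementation: a structured `elems` is flattened to a List<Tensor> for iteration
// and repacked on the way out, while a lone Tensor round-trips as a one-element list.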
|
|
|
bool input_is_sequence = nest.is_sequence(elems);

List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> { x };

object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0];
bool output_is_sequence;
Func<Tensor, List<Tensor>> output_flatten;
Func<List<Tensor>, object> output_pack;

if (initializer == null)
{
    output_is_sequence = input_is_sequence;
    output_flatten = input_flatten;
    output_pack = input_pack;
}
else
{
    output_is_sequence = nest.is_sequence(initializer);
    output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> { x };
    output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];
}
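// Note: with an initializer the accumulator's structure comes from the initializer,
// not from elems, which is why the output helpers above are derived from it instead.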
|
|
|
|
|
|
|
var elems_flat = input_flatten(elems);

bool in_graph_mode = true; // todo !context.executing_eagerly()

return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope =>
{ |
|
|
|
// todo tf.net doesn't expose .caching_device
//if (in_graph_mode)
//{
//    // Any get_variable calls in fn will cache the first call locally
//    // and not issue repeated network I/O requests for each iteration.
//    var varscope = variable_scope.get_variable_scope();
//    bool varscope_caching_device_was_none = false;
//    if (varscope.caching_device == null)
//    {
//        // varscope.set_caching_device(lambda op: op.device)
//        // varscope_caching_device_was_none = True
//    }
//}
|
|
|
|
|
|
|
elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();

// Convert elems to tensor arrays; n may be known statically.
var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);

// todo python had the below, but dimension_value returns an int, which can't be null
//if (n == null)
//{
//    n = array_ops.shape(elems_flat[0])[0];
//}
|
|
|
|
|
|
|
// TensorArrays are always flat.
var elems_ta = elems_flat.Select(elem => new TensorArray(
    elem.dtype,
    size: tf.constant(n),
    dynamic_size: false,
    element_shape: elem.shape.Skip(1).ToArray(),
    infer_shape: true)).ToList();

for (int index = 0; index < elems_ta.Count; index++)
{
    elems_ta[index].unstack(elems_flat[index]);
}
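// unstack splits each input tensor along axis 0 into n per-step slices, so
// elems_ta[k].read(i) yields step i of the k-th flattened input.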
|
|
|
|
|
|
|
List<Tensor> a_flat;
int i;
if (initializer == null)
{
    // a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
    a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
    i = 1;
}
else
{
    throw new NotImplementedException("Initializer not handled yet");
    // todo the below works in python, where initializer can be passed as a List<Tensor>
    //List<Tensor> initializer_flat = output_flatten(initializer);
    //a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
    //i = 0;
}
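// Without an initializer the first (or, when reversed, last) element seeds the
// accumulator, so iteration starts at step i = 1 instead of 0.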
|
|
|
|
|
|
|
var accs_ta = a_flat.Select(init => new TensorArray(
    dtype: init.dtype,
    size: tf.constant(n),
    element_shape: infer_shape ? init.shape : null,
    dynamic_size: false,
    infer_shape: infer_shape)).ToList();
|
|
|
|
|
|
|
// if initializer is None:
if (initializer == null)
{
    for (int index = 0; index < accs_ta.Count; index++)
|
|
@@ -135,14 +132,14 @@ namespace Tensorflow |
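// compute is the while-loop body: it reads step _i, applies fn to the packed
// accumulator and element, writes the result into the accumulator arrays, and
// then steps the index forward (or backward when reverse is set).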
|
|
|
|
|
|
|
(int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas)
{
    var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList());
    var packed_a = output_pack(a_flat_);
    var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal?

    var flat_a_out = output_flatten(a_out);
    for (int j = 0; j < tas.Count; j++)
    {
        tas[j].write(tf.constant(_i), flat_a_out[j]);
    }

    var next_i = reverse ? _i - 1 : _i + 1;
|
|
@@ -154,13 +151,11 @@ namespace Tensorflow |
|
|
|
if (reverse)
{
    initial_i = n - 1 - i;
    // condition = lambda i, _1, _2: i >= 0
    condition = x => tf.constant(x >= 0);
}
else
{
    initial_i = i;
    // condition = lambda i, _1, _2: i < n
    condition = x => tf.constant(x < n);
}
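// When no initializer was given, i is 1 and the seed element is skipped: a forward
// scan starts at step 1 and a reverse scan at n - 2.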
|
|
|
|
|
|
@@ -178,18 +173,20 @@ namespace Tensorflow |
|
|
|
|
|
|
|
var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));

// for elem in elems_flat[1:]:
foreach (var elem in elems_flat.Skip(1))
{
    n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
}
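// merge_with checks that every flattened input agrees on the statically known
// leading dimension n.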
|
|
|
|
|
|
|
foreach (Tensor r in results_flat)
{
    r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray()));
}
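// Each stacked result carries the static step count as its leading dimension,
// followed by the per-step shape (r.shape[1:] in the Python original).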
|
|
|
|
|
|
|
// todo get working when the above caching_device is fixed
//if (in_graph_mode && varscope_caching_device_was_none) {
//    varscope.set_caching_device(null);
//}
|
|
|
|
|
|
|
return output_pack(results_flat);
});
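// A rough usage sketch, assuming a tf.scan-style entry point over this routine
// (the function name and overload here are illustrative, not the confirmed public API):
//   var elems = tf.constant(new[] { 1, 2, 3, 4, 5, 6 });
//   var cumsum = scan((a, x) => a + x, elems); // -> [1, 3, 6, 10, 15, 21]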
|
|
|