diff --git a/src/TensorFlowNET.Core/APIs/tf.scan.cs b/src/TensorFlowNET.Core/APIs/tf.scan.cs index 439b0512..5642eaaf 100644 --- a/src/TensorFlowNET.Core/APIs/tf.scan.cs +++ b/src/TensorFlowNET.Core/APIs/tf.scan.cs @@ -23,7 +23,7 @@ namespace Tensorflow public Tensor scan( Func fn, Tensor elems, - IInitializer initializer = null, + Tensor initializer = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 68e56fb9..f0e1aa1c 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -28,7 +28,7 @@ namespace Tensorflow public static Tensor scan( Func fn, Tensor elems, - IInitializer initializer = null, + Tensor initializer = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, @@ -108,11 +108,9 @@ namespace Tensorflow } else { - throw new NotImplementedException("Initializer not handled yet"); - // todo the below in python, initializer is able to be passed as a List - //List initializer_flat = output_flatten(initializer); - //a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList(); - //i = 0; + List initializer_flat = output_flatten(initializer); + a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList(); + i = 0; } var accs_ta = a_flat.Select(init => new TensorArray(