diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs index 231e613e..3ea54233 100644 --- a/src/TensorFlowNET.Core/Data/MapDataset.cs +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -1,5 +1,6 @@ using System; using Tensorflow.Functions; +using static Tensorflow.Binding; namespace Tensorflow { @@ -14,7 +15,12 @@ namespace Tensorflow bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) { - var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); + var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); + var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input"); + var output = map_func(input); + func.ToGraph(input, output); + + structure = func.OutputStructure; variant_tensor = ops.map_dataset(input_dataset.variant_tensor, func, diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index a3067182..e4754860 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -109,6 +109,8 @@ namespace Tensorflow.Functions inputs, outputs, null); + + OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); } public Tensors Invoke(Tensors inputs)