using System; using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow { /// /// A `Dataset` that maps a function over elements in its input. /// public class MapDataset : UnaryDataset { public MapDataset(IDatasetV2 input_dataset, Func map_func, bool use_inter_op_parallelism = true, bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) { using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); var input = tf.placeholder(input_dataset.element_spec[0].dtype); var output = map_func(input); func.ToGraph(input, output); structure = func.OutputStructure; variant_tensor = ops.map_dataset(input_dataset.variant_tensor, func, output_types, output_shapes); } } }