using System;
using Tensorflow.Functions;
using static Tensorflow.Binding;
namespace Tensorflow
{
/// <summary>
/// A `Dataset` that maps a function over elements in its input.
/// </summary>
public class MapDataset : UnaryDataset
{
    /// <summary>
    /// Traces <paramref name="map_func"/> into a <see cref="ConcreteFunction"/> and builds the
    /// map-dataset variant tensor over <paramref name="input_dataset"/>.
    /// </summary>
    /// <param name="input_dataset">The dataset whose elements are mapped.</param>
    /// <param name="map_func">
    /// The mapping function. It is invoked exactly once in this constructor, on a placeholder
    /// tensor, to trace its graph; the traced function is what gets passed to the dataset op.
    /// </param>
    /// <param name="use_inter_op_parallelism">Currently not read by this constructor.</param>
    /// <param name="preserve_cardinality">Currently not read by this constructor.</param>
    /// <param name="use_legacy_function">Currently not read by this constructor.</param>
    /// <exception cref="ArgumentNullException"><paramref name="map_func"/> is null.</exception>
    public MapDataset(IDatasetV2 input_dataset,
        Func<Tensor, Tensor> map_func,
        bool use_inter_op_parallelism = true,
        bool preserve_cardinality = false,
        bool use_legacy_function = false) : base(input_dataset)
    {
        // Fail fast with a named argument rather than a NullReferenceException
        // from map_func.Method below.
        if (map_func is null)
        {
            throw new ArgumentNullException(nameof(map_func));
        }

        // NOTE(review): use_inter_op_parallelism / preserve_cardinality / use_legacy_function
        // are accepted but not forwarded to ops.map_dataset — TODO confirm whether the op
        // wrapper supports them.

        using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");

        // Only the dtype of the first element_spec component is used for the trace input;
        // no shape is given to the placeholder. NOTE(review): multi-component elements
        // presumably are not supported here — confirm.
        var input = tf.placeholder(input_dataset.element_spec[0].dtype);
        var output = map_func(input);
        func.ToGraph(input, output);

        // The output structure of the traced function becomes this dataset's structure;
        // output_types / output_shapes are read back from the base class (presumably derived
        // from `structure` — verify in UnaryDataset/DatasetV2).
        structure = func.OutputStructure;
        variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
            func,
            output_types,
            output_shapes);
    }
}
}