diff --git a/src/TensorFlowNET.Core/Data/DatasetOptions.cs b/src/TensorFlowNET.Core/Data/DatasetOptions.cs
new file mode 100644
index 00000000..4b1b0b56
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/DatasetOptions.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    public class DatasetOptions
+    {
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 8d7512eb..e5e6eb75 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -60,12 +60,18 @@ namespace Tensorflow
                 preserve_cardinality: preserve_cardinality,
                 use_legacy_function: use_legacy_function);
 
+        public IDatasetV2 map(Func map_func, int num_parallel_calls = -1)
+            => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
+
         public IDatasetV2 flat_map(Func map_func)
             => new FlatMapDataset(this, map_func);
 
         public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
             => new ModelDataset(this, algorithm, cpu_budget);
 
+        public IDatasetV2 with_options(DatasetOptions options)
+            => new OptionsDataset(this, options);
+
         public IDatasetV2 apply_options()
         {
             // (1) Apply threading options
@@ -94,7 +100,7 @@ namespace Tensorflow
         }
 
         public override string ToString()
-            => $"{GetType().Name} shapes: ({structure[0].shape}, {structure[1].shape}), types: (tf.{structure[0].dtype.as_numpy_name()}, tf.{structure[1].dtype.as_numpy_name()})";
+            => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";
 
         public IEnumerator<(Tensor, Tensor)> GetEnumerator()
         {
diff --git a/src/TensorFlowNET.Core/Data/FlatMapDataset.cs b/src/TensorFlowNET.Core/Data/FlatMapDataset.cs
index 129c4f76..202a67f9 100644
--- a/src/TensorFlowNET.Core/Data/FlatMapDataset.cs
+++ b/src/TensorFlowNET.Core/Data/FlatMapDataset.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Tensorflow.Functions;
 
@@ -14,7 +15,7 @@ namespace Tensorflow
             Func map_func) : base(input_dataset)
         {
             var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
-
+            structure = func.OutputStructure;
             variant_tensor = ops.flat_map_dataset(input_dataset.variant_tensor,
                 func,
                 output_types,
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index 6b04eecd..2a96dca2 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -62,10 +62,15 @@ namespace Tensorflow
             bool preserve_cardinality = false,
             bool use_legacy_function = false);
 
+        IDatasetV2 map(Func map_func,
+            int num_parallel_calls = -1);
+
         IDatasetV2 flat_map(Func map_func);
 
         IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);
 
+        IDatasetV2 with_options(DatasetOptions options);
+
         /// <summary>
         /// Apply options, such as optimization configuration, to the dataset.
         /// </summary>
diff --git a/src/TensorFlowNET.Core/Data/OptionsDataset.cs b/src/TensorFlowNET.Core/Data/OptionsDataset.cs
new file mode 100644
index 00000000..cce22e65
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/OptionsDataset.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// An identity `Dataset` that stores options.
+    /// </summary>
+    public class OptionsDataset : UnaryUnchangedStructureDataset
+    {
+        DatasetOptions options;
+
+        public OptionsDataset(IDatasetV2 input_dataset, DatasetOptions options)
+            : base(input_dataset)
+        {
+            this.options = options;
+            variant_tensor = input_dataset.variant_tensor;
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs
new file mode 100644
index 00000000..c84eb328
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Functions;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    //A `Dataset` that maps a function over elements in its input in parallel.
+    public class ParallelMapDataset : UnaryDataset
+    {
+        public ParallelMapDataset(IDatasetV2 input_dataset,
+            Func map_func,
+            int num_parallel_calls = -1,
+            bool use_inter_op_parallelism = true,
+            bool preserve_cardinality = false,
+            bool use_legacy_function = false) : base(input_dataset)
+        {
+            var func = new ConcreteFunction(map_func,
+                input_dataset.element_spec.Select(x => x.dtype).ToArray(),
+                input_dataset.element_spec.Select(x => x.shape).ToArray());
+
+            structure = func.OutputStructure;
+            var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64,
+                name: "num_parallel_calls");
+            variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor,
+                _num_parallel_calls,
+                func,
+                output_types,
+                output_shapes);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/TensorDataset.cs b/src/TensorFlowNET.Core/Data/TensorDataset.cs
index 0a001d61..a3584886 100644
--- a/src/TensorFlowNET.Core/Data/TensorDataset.cs
+++ b/src/TensorFlowNET.Core/Data/TensorDataset.cs
@@ -15,11 +15,9 @@ namespace Tensorflow
         public TensorDataset(Tensor feature, Tensor label)
         {
             _tensors = new[] { feature, label };
-            var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
-            structure = batched_spec.Select(x => x._unbatch()).ToArray();
+            structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();
 
             variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
-
         }
         public TensorDataset(Tensor element)
         {
diff --git a/src/TensorFlowNET.Core/Data/ZipDataset.cs b/src/TensorFlowNET.Core/Data/ZipDataset.cs
index b5d7e189..e7fea1cd 100644
--- a/src/TensorFlowNET.Core/Data/ZipDataset.cs
+++ b/src/TensorFlowNET.Core/Data/ZipDataset.cs
@@ -2,16 +2,19 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using Tensorflow.Framework.Models;
 
 namespace Tensorflow
 {
     public class ZipDataset : DatasetV2
     {
-        dataset_ops ops = new dataset_ops();
         public ZipDataset(params IDatasetV2[] ds)
         {
             var input_datasets = ds.Select(x => x.variant_tensor).ToArray();
-            structure = ds.Select(x => x.structure[0]).ToArray();
+            var _structure = new List<TensorSpec>();
+            foreach (var dataset in ds)
+                _structure.AddRange(dataset.structure);
+            structure = _structure.ToArray();
             variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes);
         }
     }