diff --git a/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs b/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs new file mode 100644 index 00000000..9d4abd6b --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using static Tensorflow.Binding; + +namespace Tensorflow.Data +{ + /// + /// A `Dataset` that concatenates its input with given dataset. + /// + public class ConcatenateDataset : DatasetV2 + { + IDatasetV2 _input_dataset; + IDatasetV2 _dataset_to_concatenate; + public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate) + { + _input_dataset = input_dataset; + _dataset_to_concatenate = dataset_to_concatenate; + var _structure = new List(); + foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec)) + { + var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape); + _structure.Add(new TensorSpec(shape, dtype: spec.dtype)); + } + structure = _structure.ToArray(); + + variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor, + dataset_to_concatenate.variant_tensor, + output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 850211d1..2abe9970 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -2,6 +2,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using Tensorflow.Data; using Tensorflow.Framework.Models; using static Tensorflow.Binding; @@ -26,6 +27,9 @@ namespace Tensorflow public IDatasetV2 cache(string filename = "") => new CacheDataset(this, filename: filename); + public IDatasetV2 concatenate(IDatasetV2 dataset) + => new ConcatenateDataset(this, dataset); + public IDatasetV2 take(int count = -1) => new TakeDataset(this, count: count); diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 5240f550..4d9b00d2 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -23,6 +23,13 @@ namespace Tensorflow /// IDatasetV2 cache(string filename = ""); + /// + /// Creates a `Dataset` by concatenating the given dataset with this dataset. + /// + /// + /// + IDatasetV2 concatenate(IDatasetV2 dataset); + /// /// /// diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs index 6327d174..84ca0a9b 100644 --- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -1,5 +1,7 @@ using System; +using System.Linq; using Tensorflow.Framework.Models; +using static Tensorflow.Binding; namespace Tensorflow { @@ -36,7 +38,10 @@ namespace Tensorflow { try { - return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); + var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); + foreach(var (i, tensor) in enumerate(results)) + tensor.set_shape(_element_spec[i].shape); + return results; } catch (OutOfRangeError ex) {