diff --git a/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs b/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs
new file mode 100644
index 00000000..9d4abd6b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Framework;
+using Tensorflow.Framework.Models;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Data
+{
+ ///
+ /// A `Dataset` that concatenates its input with given dataset.
+ ///
+ public class ConcatenateDataset : DatasetV2
+ {
+ IDatasetV2 _input_dataset;
+ IDatasetV2 _dataset_to_concatenate;
+ public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate)
+ {
+ _input_dataset = input_dataset;
+ _dataset_to_concatenate = dataset_to_concatenate;
+ var _structure = new List();
+ foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec))
+ {
+ var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape);
+ _structure.Add(new TensorSpec(shape, dtype: spec.dtype));
+ }
+ structure = _structure.ToArray();
+
+ variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor,
+ dataset_to_concatenate.variant_tensor,
+ output_types, output_shapes);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 850211d1..2abe9970 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -2,6 +2,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Data;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;
@@ -26,6 +27,9 @@ namespace Tensorflow
public IDatasetV2 cache(string filename = "")
=> new CacheDataset(this, filename: filename);
+ public IDatasetV2 concatenate(IDatasetV2 dataset)
+ => new ConcatenateDataset(this, dataset);
+
public IDatasetV2 take(int count = -1)
=> new TakeDataset(this, count: count);
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index 5240f550..4d9b00d2 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -23,6 +23,13 @@ namespace Tensorflow
///
IDatasetV2 cache(string filename = "");
+ ///
+ /// Creates a `Dataset` by concatenating the given dataset with this dataset.
+ ///
+ ///
+ ///
+ IDatasetV2 concatenate(IDatasetV2 dataset);
+
///
///
///
diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
index 6327d174..84ca0a9b 100644
--- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs
+++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
@@ -1,5 +1,7 @@
using System;
+using System.Linq;
using Tensorflow.Framework.Models;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -36,7 +38,10 @@ namespace Tensorflow
{
try
{
- return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
+ var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
+ foreach(var (i, tensor) in enumerate(results))
+ tensor.set_shape(_element_spec[i].shape);
+ return results;
}
catch (OutOfRangeError ex)
{