diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index d111ac2c..bbdaa8b6 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -35,6 +35,9 @@ namespace Tensorflow
public IDatasetV2 repeat(int count = -1)
=> new RepeatDataset(this, count: count);
+        public IDatasetV2 shard(int num_shards, int index)
+            => new ShardDataset(input_dataset: this, num_shards: num_shards, index: index);
+
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index c5c32013..71e5fa5f 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -24,6 +24,14 @@ namespace Tensorflow
///
IDatasetV2 repeat(int count = -1);
+ /// <summary>
+ /// Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
+ /// </summary>
+ /// <param name="num_shards">The number of shards operating in parallel</param>
+ /// <param name="index">The worker index</param>
+ /// <returns>The sharded dataset.</returns>
+ IDatasetV2 shard(int num_shards, int index);
+
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);
IDatasetV2 batch(int batch_size, bool drop_remainder = false);
diff --git a/src/TensorFlowNET.Core/Data/ShardDataset.cs b/src/TensorFlowNET.Core/Data/ShardDataset.cs
new file mode 100644
index 00000000..4bb8ebfc
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/ShardDataset.cs
@@ -0,0 +1,31 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// A `Dataset` for sharding its input.
+    /// </summary>
+    public class ShardDataset : UnaryUnchangedStructureDataset
+    {
+        private readonly Tensor _num_shards;
+        private readonly Tensor _index;
+
+        /// <summary>Builds the variant tensor for shard <paramref name="index"/> of <paramref name="num_shards"/> over <paramref name="input_dataset"/>.</summary>
+        public ShardDataset(IDatasetV2 input_dataset, int num_shards, int index) : base(input_dataset)
+        {
+            // Both scalars are handed to the op wrapper as int64 tensors.
+            _num_shards = tf.convert_to_tensor(num_shards, dtype: TF_DataType.TF_INT64, name: "num_shards");
+            _index = tf.convert_to_tensor(index, dtype: TF_DataType.TF_INT64, name: "index");
+
+            // Output types and shapes are taken unchanged from the input dataset.
+            variant_tensor = ops.shard_dataset(
+                input_dataset.variant_tensor,
+                num_shards: _num_shards,
+                index: _index,
+                output_types: input_dataset.output_types,
+                output_shapes: input_dataset.output_shapes);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
index fbbb12b6..feeb62e0 100644
--- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
@@ -65,6 +65,25 @@ namespace Tensorflow
throw new NotImplementedException("");
}
+        /// <summary>Runs the `ShardDataset` op through the eager fast path and returns its first output; throws <see cref="NotImplementedException"/> when not executing eagerly.</summary>
+        public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index,
+            TF_DataType[] output_types, TensorShape[] output_shapes,
+            bool require_non_empty = false, string name = null)
+        {
+            if (!tf.Context.executing_eagerly())
+            {
+                throw new NotImplementedException("");
+            }
+
+            var outputs = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                "ShardDataset", name, null,
+                input_dataset, num_shards, index,
+                "require_non_empty", require_non_empty,
+                "output_types", output_types,
+                "output_shapes", output_shapes);
+            return outputs[0];
+        }
+
public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
Tensor seed, Tensor seed2, Tensor seed_generator,
TF_DataType[] output_types, TensorShape[] output_shapes,
diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
index 47d53824..b8e0bcb3 100644
--- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
+++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
@@ -61,5 +61,28 @@ namespace TensorFlowNET.UnitTest.Dataset
}
Assert.AreEqual(5, n);
}
+
+        [TestMethod]
+        public void Shard()
+        {
+            var dataset1 = tf.data.Dataset.range(10);
+            // Shard 0 of 3 over range(10) must yield exactly 0, 3, 6, 9.
+            long value = 0;
+            foreach (var item in dataset1.shard(num_shards: 3, index: 0))
+            {
+                Assert.AreEqual(value, (long)item.Item1);
+                value += 3;
+            }
+            Assert.AreEqual(12L, value); // fails if the shard was empty or truncated
+
+            // Shard 1 of 3 over range(10) must yield exactly 1, 4, 7.
+            value = 1;
+            foreach (var item in dataset1.shard(num_shards: 3, index: 1))
+            {
+                Assert.AreEqual(value, (long)item.Item1);
+                value += 3;
+            }
+            Assert.AreEqual(10L, value); // fails if the shard was empty or truncated
+        }
}
}