From a174a84d3d8c3a4c1a3b65c5e582a38f6ab93843 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Sat, 22 Aug 2020 17:50:21 -0500
Subject: [PATCH] tf.data.Dataset shard() #446

---
 src/TensorFlowNET.Core/Data/DatasetV2.cs      |  3 ++
 src/TensorFlowNET.Core/Data/IDatasetV2.cs     |  8 ++++
 src/TensorFlowNET.Core/Data/ShardDataset.cs   | 38 ++++++++++++++++++
 .../Operations/dataset_ops.cs                 | 31 +++++++++++++++
 .../Dataset/DatasetTest.cs                    | 34 ++++++++++++++++
 5 files changed, 114 insertions(+)
 create mode 100644 src/TensorFlowNET.Core/Data/ShardDataset.cs

diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index d111ac2c..bbdaa8b6 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -35,6 +35,9 @@ namespace Tensorflow
         public IDatasetV2 repeat(int count = -1)
             => new RepeatDataset(this, count: count);
 
+        public IDatasetV2 shard(int num_shards, int index)
+            => new ShardDataset(this, num_shards, index);
+
         public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
             => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
 
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index c5c32013..71e5fa5f 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -24,6 +24,14 @@ namespace Tensorflow
         /// <returns></returns>
         IDatasetV2 repeat(int count = -1);
 
+        /// <summary>
+        /// Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
+        /// </summary>
+        /// <param name="num_shards">The number of shards operating in parallel</param>
+        /// <param name="index">The worker index</param>
+        /// <returns>The sharded dataset</returns>
+        IDatasetV2 shard(int num_shards, int index);
+
         IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);
 
         IDatasetV2 batch(int batch_size, bool drop_remainder = false);
diff --git a/src/TensorFlowNET.Core/Data/ShardDataset.cs b/src/TensorFlowNET.Core/Data/ShardDataset.cs
new file mode 100644
index 00000000..4bb8ebfc
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/ShardDataset.cs
@@ -0,0 +1,38 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// A `Dataset` for sharding its input.
+    /// </summary>
+    public class ShardDataset : UnaryUnchangedStructureDataset
+    {
+        Tensor _num_shards;
+        Tensor _index;
+
+        /// <summary>
+        /// Runs the shard op over <paramref name="input_dataset"/> and keeps its result as this dataset's variant tensor.
+        /// </summary>
+        /// <param name="input_dataset">The dataset to shard</param>
+        /// <param name="num_shards">The number of shards operating in parallel</param>
+        /// <param name="index">The worker index</param>
+        public ShardDataset(IDatasetV2 input_dataset,
+            int num_shards,
+            int index) : base(input_dataset)
+        {
+            // Both arguments are converted to int64 tensors before being handed to the op.
+            _num_shards = tf.convert_to_tensor(num_shards, dtype: TF_DataType.TF_INT64, name: "num_shards");
+            _index = tf.convert_to_tensor(index, dtype: TF_DataType.TF_INT64, name: "index");
+
+            variant_tensor = ops.shard_dataset
+                (input_dataset.variant_tensor,
+                num_shards: _num_shards,
+                index: _index,
+                input_dataset.output_types,
+                input_dataset.output_shapes);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
index fbbb12b6..feeb62e0 100644
--- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
@@ -65,6 +65,37 @@ namespace Tensorflow
             throw new NotImplementedException("");
         }
 
+        /// <summary>
+        /// Runs the "ShardDataset" op through the eager fast path.
+        /// </summary>
+        /// <param name="input_dataset">Variant tensor of the dataset to shard</param>
+        /// <param name="num_shards">Tensor holding the number of shards</param>
+        /// <param name="index">Tensor holding the index of the shard to keep</param>
+        /// <param name="output_types">Forwarded as the op's "output_types" attribute</param>
+        /// <param name="output_shapes">Forwarded as the op's "output_shapes" attribute</param>
+        /// <param name="require_non_empty">Forwarded as the op's "require_non_empty" attribute</param>
+        /// <param name="name">Optional name for the op</param>
+        /// <returns>The first result returned by the op</returns>
+        /// <exception cref="NotImplementedException">Thrown when not executing eagerly</exception>
+        public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index,
+            TF_DataType[] output_types, TensorShape[] output_shapes,
+            bool require_non_empty = false, string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "ShardDataset", name,
+                    null,
+                    input_dataset, num_shards, index,
+                    "require_non_empty", require_non_empty,
+                    "output_types", output_types,
+                    "output_shapes", output_shapes);
+                return results[0];
+            }
+
+            throw new NotImplementedException("");
+        }
+
         public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
             Tensor seed, Tensor seed2, Tensor seed_generator,
             TF_DataType[] output_types, TensorShape[] output_shapes,
diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
index 47d53824..b8e0bcb3 100644
--- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
+++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
@@ -61,5 +61,39 @@ namespace TensorFlowNET.UnitTest.Dataset
             }
             Assert.AreEqual(5, n);
         }
+
+        /// <summary>
+        /// Each shard of range(10) should yield every third element, starting at its own index.
+        /// </summary>
+        [TestMethod]
+        public void Shard()
+        {
+            long value = 0;
+            int n = 0;
+
+            var dataset1 = tf.data.Dataset.range(10);
+            var dataset2 = dataset1.shard(num_shards: 3, index: 0);
+
+            foreach (var item in dataset2)
+            {
+                Assert.AreEqual(value, (long)item.Item1);
+                value += 3;
+                n += 1;
+            }
+            // Shard 0 of 3 holds 0, 3, 6, 9; counting guards against an empty or truncated shard.
+            Assert.AreEqual(4, n);
+
+            value = 1;
+            n = 0;
+            var dataset3 = dataset1.shard(num_shards: 3, index: 1);
+            foreach (var item in dataset3)
+            {
+                Assert.AreEqual(value, (long)item.Item1);
+                value += 3;
+                n += 1;
+            }
+            // Shard 1 of 3 holds 1, 4, 7.
+            Assert.AreEqual(3, n);
+        }
     }
 }