From 68df1b70f00626b51f395485b4f6ad1ab1e9913a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 22 Aug 2020 18:08:23 -0500 Subject: [PATCH] tf.data.Dataset skip() #446 --- src/TensorFlowNET.Core/Data/DatasetV2.cs | 3 +++ src/TensorFlowNET.Core/Data/IDatasetV2.cs | 7 ++++++ src/TensorFlowNET.Core/Data/SkipDataset.cs | 24 +++++++++++++++++++ .../Operations/dataset_ops.cs | 18 ++++++++++++++ .../Dataset/DatasetTest.cs | 15 ++++++++++++ 5 files changed, 67 insertions(+) create mode 100644 src/TensorFlowNET.Core/Data/SkipDataset.cs diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index bbdaa8b6..cd330f4a 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -41,6 +41,9 @@ namespace Tensorflow public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); + public IDatasetV2 skip(int count) + => new SkipDataset(this, count); + public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 71e5fa5f..1a0d88cf 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -34,6 +34,13 @@ namespace Tensorflow IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); + /// + /// Creates a `Dataset` that skips `count` elements from this dataset. + /// + /// + /// + IDatasetV2 skip(int count); + IDatasetV2 batch(int batch_size, bool drop_remainder = false); IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); diff --git a/src/TensorFlowNET.Core/Data/SkipDataset.cs b/src/TensorFlowNET.Core/Data/SkipDataset.cs new file mode 100644 index 00000000..1bcfd3fa --- /dev/null +++ b/src/TensorFlowNET.Core/Data/SkipDataset.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` skipping the first `count` elements from its input. + /// + public class SkipDataset : UnaryUnchangedStructureDataset + { + Tensor _count; + + public SkipDataset(IDatasetV2 input_dataset, + int count) : base(input_dataset) + { + _count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count"); + variant_tensor = ops.skip_dataset(input_dataset.variant_tensor, + _count, + output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index feeb62e0..276dc462 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -106,6 +106,24 @@ namespace Tensorflow throw new NotImplementedException(""); } + public Tensor skip_dataset(Tensor input_dataset, Tensor count, + TF_DataType[] output_types, TensorShape[] output_shapes, + string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "SkipDataset", name, + null, + input_dataset, count, + "output_types", output_types, + "output_shapes", output_shapes); + return results[0]; + } + + throw new NotImplementedException(""); + } + public Tensor dummy_seed_generator(string name = null) { if (tf.Context.executing_eagerly()) diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index b8e0bcb3..37430980 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -84,5 +84,20 @@ namespace TensorFlowNET.UnitTest.Dataset value += 3; } } + + [TestMethod] + public void Skip() + { + long value = 7; + + var dataset = tf.data.Dataset.range(10); + dataset = dataset.skip(7); + + foreach (var item in dataset) + { + Assert.AreEqual(value, (long)item.Item1); + value ++; + } + } } }