@@ -41,6 +41,9 @@ namespace Tensorflow | |||||
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | ||||
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | ||||
public IDatasetV2 skip(int count) | |||||
=> new SkipDataset(this, count); | |||||
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | ||||
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | ||||
@@ -34,6 +34,13 @@ namespace Tensorflow | |||||
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); | IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); | ||||
/// <summary> | |||||
/// Creates a `Dataset` that skips `count` elements from this dataset. | |||||
/// </summary> | |||||
/// <param name="count"></param> | |||||
/// <returns></returns> | |||||
IDatasetV2 skip(int count); | |||||
IDatasetV2 batch(int batch_size, bool drop_remainder = false); | IDatasetV2 batch(int batch_size, bool drop_remainder = false); | ||||
IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | ||||
@@ -0,0 +1,24 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// A `Dataset` skipping the first `count` elements from its input. | |||||
/// </summary> | |||||
public class SkipDataset : UnaryUnchangedStructureDataset | |||||
{ | |||||
Tensor _count; | |||||
public SkipDataset(IDatasetV2 input_dataset, | |||||
int count) : base(input_dataset) | |||||
{ | |||||
_count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count"); | |||||
variant_tensor = ops.skip_dataset(input_dataset.variant_tensor, | |||||
_count, | |||||
output_types, output_shapes); | |||||
} | |||||
} | |||||
} |
@@ -106,6 +106,24 @@ namespace Tensorflow | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
public Tensor skip_dataset(Tensor input_dataset, Tensor count, | |||||
TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
string name = null) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"SkipDataset", name, | |||||
null, | |||||
input_dataset, count, | |||||
"output_types", output_types, | |||||
"output_shapes", output_shapes); | |||||
return results[0]; | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public Tensor dummy_seed_generator(string name = null) | public Tensor dummy_seed_generator(string name = null) | ||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
@@ -84,5 +84,20 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
value += 3; | value += 3; | ||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public void Skip() | |||||
{ | |||||
long value = 7; | |||||
var dataset = tf.data.Dataset.range(10); | |||||
dataset = dataset.skip(7); | |||||
foreach (var item in dataset) | |||||
{ | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value ++; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |