Browse Source

tf.data.Dataset skip() #446

tags/v0.20
Oceania2018 5 years ago
parent
commit
68df1b70f0
5 changed files with 67 additions and 0 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +7
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  3. +24
    -0
      src/TensorFlowNET.Core/Data/SkipDataset.cs
  4. +18
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  5. +15
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 3
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -41,6 +41,9 @@ namespace Tensorflow
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);


public IDatasetV2 skip(int count)
=> new SkipDataset(this, count);

public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);




+ 7
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -34,6 +34,13 @@ namespace Tensorflow


IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);


/// <summary>
/// Creates a `Dataset` that skips `count` elements from this dataset.
/// </summary>
/// <param name="count"></param>
/// <returns></returns>
IDatasetV2 skip(int count);

IDatasetV2 batch(int batch_size, bool drop_remainder = false); IDatasetV2 batch(int batch_size, bool drop_remainder = false);


IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null);


+ 24
- 0
src/TensorFlowNET.Core/Data/SkipDataset.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` skipping the first `count` elements from its input.
/// </summary>
public class SkipDataset : UnaryUnchangedStructureDataset
{
Tensor _count;

public SkipDataset(IDatasetV2 input_dataset,
int count) : base(input_dataset)
{
_count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count");
variant_tensor = ops.skip_dataset(input_dataset.variant_tensor,
_count,
output_types, output_shapes);
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -106,6 +106,24 @@ namespace Tensorflow
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


public Tensor skip_dataset(Tensor input_dataset, Tensor count,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SkipDataset", name,
null,
input_dataset, count,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}

public Tensor dummy_seed_generator(string name = null) public Tensor dummy_seed_generator(string name = null)
{ {
if (tf.Context.executing_eagerly()) if (tf.Context.executing_eagerly())


+ 15
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -84,5 +84,20 @@ namespace TensorFlowNET.UnitTest.Dataset
value += 3; value += 3;
} }
} }

[TestMethod]
public void Skip()
{
long value = 7;

var dataset = tf.data.Dataset.range(10);
dataset = dataset.skip(7);

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1);
value ++;
}
}
} }
} }

Loading…
Cancel
Save