Browse Source

tf.data.Dataset shard() #446

tags/v0.20
Oceania2018 5 years ago
parent
commit
a174a84d3d
5 changed files with 84 additions and 0 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +8
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  3. +31
    -0
      src/TensorFlowNET.Core/Data/ShardDataset.cs
  4. +19
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  5. +23
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 3
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -35,6 +35,9 @@ namespace Tensorflow
/// <summary>
/// Wraps this dataset in a <see cref="RepeatDataset"/>.
/// </summary>
/// <param name="count">Repeat count handed to the wrapper; defaults to -1.</param>
/// <returns>The newly constructed <see cref="RepeatDataset"/>.</returns>
public IDatasetV2 repeat(int count = -1)
{
    return new RepeatDataset(this, count: count);
}

/// <summary>
/// Wraps this dataset in a <see cref="ShardDataset"/>.
/// </summary>
/// <param name="num_shards">Total number of shards, forwarded to the wrapper.</param>
/// <param name="index">Index of the shard to keep, forwarded to the wrapper.</param>
/// <returns>The newly constructed <see cref="ShardDataset"/>.</returns>
public IDatasetV2 shard(int num_shards, int index)
{
    return new ShardDataset(input_dataset: this, num_shards: num_shards, index: index);
}

/// <summary>
/// Wraps this dataset in a <see cref="ShuffleDataset"/>.
/// </summary>
/// <param name="buffer_size">Buffer size forwarded to the wrapper.</param>
/// <param name="seed">Optional seed forwarded to the wrapper; defaults to null.</param>
/// <param name="reshuffle_each_iteration">Flag forwarded to the wrapper; defaults to true.</param>
/// <returns>The newly constructed <see cref="ShuffleDataset"/>.</returns>
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
{
    var shuffled = new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
    return shuffled;
}



+ 8
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -24,6 +24,14 @@ namespace Tensorflow
/// <returns></returns>
IDatasetV2 repeat(int count = -1);

/// <summary>
/// Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
/// </summary>
/// <param name="num_shards">The number of shards operating in parallel.</param>
/// <param name="index">The worker index, i.e. which of the `num_shards` shards to keep.</param>
/// <returns>A new <see cref="IDatasetV2"/> representing the selected shard.</returns>
IDatasetV2 shard(int num_shards, int index);

/// <summary>
/// Returns a dataset that applies a shuffle transformation to this dataset.
/// </summary>
/// <param name="buffer_size">
/// Shuffle buffer size. NOTE(review): presumably the number of elements buffered
/// for random sampling — confirm against the implementation.
/// </param>
/// <param name="seed">Optional random seed; defaults to null.</param>
/// <param name="reshuffle_each_iteration">
/// Defaults to true. NOTE(review): presumably controls whether the order changes
/// on every iteration — confirm against the implementation.
/// </param>
/// <returns>An <see cref="IDatasetV2"/> for the shuffled dataset.</returns>
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);

/// <summary>
/// Returns a dataset that applies a batching transformation to this dataset.
/// NOTE(review): presumably combines consecutive elements into groups of
/// <paramref name="batch_size"/>; the implementation is not visible here — confirm.
/// </summary>
/// <param name="batch_size">Requested batch size.</param>
/// <param name="drop_remainder">
/// Defaults to false. NOTE(review): presumably drops a final batch smaller than
/// <paramref name="batch_size"/> when true — confirm.
/// </param>
/// <returns>An <see cref="IDatasetV2"/> for the batched dataset.</returns>
IDatasetV2 batch(int batch_size, bool drop_remainder = false);


+ 31
- 0
src/TensorFlowNET.Core/Data/ShardDataset.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Dataset node that runs the "ShardDataset" op over an input dataset,
/// passing the input's output types and shapes through unchanged.
/// </summary>
public class ShardDataset : UnaryUnchangedStructureDataset
{
    Tensor _num_shards;
    Tensor _index;

    /// <summary>
    /// Builds the shard op on top of <paramref name="input_dataset"/>.
    /// </summary>
    /// <param name="input_dataset">Dataset whose variant tensor is fed to the op.</param>
    /// <param name="num_shards">Total number of shards.</param>
    /// <param name="index">Index of the shard to keep.</param>
    public ShardDataset(IDatasetV2 input_dataset, int num_shards, int index)
        : base(input_dataset)
    {
        // Both scalars are handed to the op as int64 tensors.
        _num_shards = tf.convert_to_tensor(num_shards, dtype: TF_DataType.TF_INT64, name: "num_shards");
        _index = tf.convert_to_tensor(index, dtype: TF_DataType.TF_INT64, name: "index");

        var sharded = ops.shard_dataset(
            input_dataset.variant_tensor,
            _num_shards,
            _index,
            input_dataset.output_types,
            input_dataset.output_shapes);

        variant_tensor = sharded;
    }
}
}

+ 19
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -65,6 +65,25 @@ namespace Tensorflow
throw new NotImplementedException("");
}

/// <summary>
/// Executes the "ShardDataset" op through the eager fast path.
/// </summary>
/// <param name="input_dataset">Variant tensor of the dataset to shard.</param>
/// <param name="num_shards">Tensor holding the total number of shards.</param>
/// <param name="index">Tensor holding the index of the shard to keep.</param>
/// <param name="output_types">Value passed as the op's "output_types" attribute.</param>
/// <param name="output_shapes">Value passed as the op's "output_shapes" attribute.</param>
/// <param name="require_non_empty">Value passed as the op's "require_non_empty" attribute; defaults to false.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first output of the executed op.</returns>
/// <exception cref="NotImplementedException">Thrown when not executing eagerly.</exception>
public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index,
    TF_DataType[] output_types, TensorShape[] output_shapes,
    bool require_non_empty = false, string name = null)
{
    // Only the eager path is implemented; bail out early otherwise.
    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("");
    }

    var outputs = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "ShardDataset", name,
        null,
        input_dataset, num_shards, index,
        "require_non_empty", require_non_empty,
        "output_types", output_types,
        "output_shapes", output_shapes);

    return outputs[0];
}

public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
Tensor seed, Tensor seed2, Tensor seed_generator,
TF_DataType[] output_types, TensorShape[] output_shapes,


+ 23
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -61,5 +61,28 @@ namespace TensorFlowNET.UnitTest.Dataset
}
Assert.AreEqual(5, n);
}

[TestMethod]
public void Shard()
{
    var dataset1 = tf.data.Dataset.range(10);

    // Shard 0 of 3 over range(10) is expected to yield 0, 3, 6, 9.
    long value = 0;
    int count = 0;
    var dataset2 = dataset1.shard(num_shards: 3, index: 0);
    foreach (var item in dataset2)
    {
        Assert.AreEqual(value, (long)item.Item1);
        value += 3;
        count++;
    }
    // Without this check the test would pass even if the shard produced no elements.
    Assert.AreEqual(4, count);

    // Shard 1 of 3 over range(10) is expected to yield 1, 4, 7.
    value = 1;
    count = 0;
    var dataset3 = dataset1.shard(num_shards: 3, index: 1);
    foreach (var item in dataset3)
    {
        Assert.AreEqual(value, (long)item.Item1);
        value += 3;
        count++;
    }
    Assert.AreEqual(3, count);
}
}
}

Loading…
Cancel
Save