Browse Source

Add ParallelMapDataset

tags/v0.30
Oceania2018 4 years ago
parent
commit
ab36f2cb5a
8 changed files with 85 additions and 7 deletions
  1. +10
    -0
      src/TensorFlowNET.Core/Data/DatasetOptions.cs
  2. +7
    -1
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Data/FlatMapDataset.cs
  4. +5
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  5. +21
    -0
      src/TensorFlowNET.Core/Data/OptionsDataset.cs
  6. +34
    -0
      src/TensorFlowNET.Core/Data/ParallelMapDataset.cs
  7. +1
    -3
      src/TensorFlowNET.Core/Data/TensorDataset.cs
  8. +5
    -2
      src/TensorFlowNET.Core/Data/ZipDataset.cs

+ 10
- 0
src/TensorFlowNET.Core/Data/DatasetOptions.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Options object that can be attached to a dataset.
    /// The type currently declares no members of its own.
    /// </summary>
    public class DatasetOptions
    {
    }
}

+ 7
- 1
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -60,12 +60,18 @@ namespace Tensorflow
preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function);

/// <summary>
/// Wraps this dataset in a <see cref="ParallelMapDataset"/> that applies <paramref name="map_func"/>.
/// </summary>
/// <param name="map_func">Function handed to the new <see cref="ParallelMapDataset"/>.</param>
/// <param name="num_parallel_calls">
/// Forwarded unchanged to <see cref="ParallelMapDataset"/>; defaults to -1
/// (presumably the "autotune" sentinel — TODO confirm).
/// </param>
/// <returns>The newly created <see cref="ParallelMapDataset"/>.</returns>
public IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1)
{
    return new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
}

/// <summary>
/// Wraps this dataset in a <see cref="FlatMapDataset"/> built from <paramref name="map_func"/>.
/// </summary>
/// <param name="map_func">Function from a tensor to a dataset, handed to the new <see cref="FlatMapDataset"/>.</param>
/// <returns>The newly created <see cref="FlatMapDataset"/>.</returns>
public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
{
    return new FlatMapDataset(this, map_func);
}

/// <summary>
/// Wraps this dataset in a <see cref="ModelDataset"/> configured with the given arguments.
/// </summary>
/// <param name="algorithm">Autotune algorithm passed through to <see cref="ModelDataset"/>.</param>
/// <param name="cpu_budget">CPU budget passed through to <see cref="ModelDataset"/>.</param>
/// <returns>The newly created <see cref="ModelDataset"/>.</returns>
public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
{
    return new ModelDataset(this, algorithm, cpu_budget);
}

/// <summary>
/// Wraps this dataset in an <see cref="OptionsDataset"/> carrying <paramref name="options"/>.
/// </summary>
/// <param name="options">Options object stored by the new <see cref="OptionsDataset"/>.</param>
/// <returns>The newly created <see cref="OptionsDataset"/>.</returns>
public IDatasetV2 with_options(DatasetOptions options)
{
    return new OptionsDataset(this, options);
}

public IDatasetV2 apply_options()
{
// (1) Apply threading options
@@ -94,7 +100,7 @@ namespace Tensorflow
}

/// <summary>
/// Renders the dataset as its runtime type name followed by the shape and the
/// "tf."-prefixed numpy dtype name of every entry in <c>structure</c>.
/// </summary>
/// <returns>
/// A string of the form <c>"{TypeName} shapes: s0, s1, ..., types: tf.t0, tf.t1, ..."</c>.
/// </returns>
public override string ToString()
{
    // Iterate over the whole structure instead of indexing [0] and [1], so datasets
    // with one component (or more than two) can be printed without an index error.
    var shapes = string.Join(", ", structure.Select(x => x.shape));
    var types = string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()));
    return $"{GetType().Name} shapes: {shapes}, types: {types}";
}

public IEnumerator<(Tensor, Tensor)> GetEnumerator()
{


+ 2
- 1
src/TensorFlowNET.Core/Data/FlatMapDataset.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;

@@ -14,7 +15,7 @@ namespace Tensorflow
Func<Tensor, IDatasetV2> map_func) : base(input_dataset)
{
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
structure = func.OutputStructure;
variant_tensor = ops.flat_map_dataset(input_dataset.variant_tensor,
func,
output_types,


+ 5
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -62,10 +62,15 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false);

/// <summary>
/// Returns a dataset built by applying <paramref name="map_func"/> to this dataset.
/// </summary>
/// <param name="map_func">Function taking a tensor and a tensor pair and returning a tensor pair.</param>
/// <param name="num_parallel_calls">
/// Degree of parallelism requested; defaults to -1
/// (presumably meaning "let the runtime decide" — TODO confirm).
/// </param>
IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
int num_parallel_calls = -1);

/// <summary>
/// Returns a dataset built from <paramref name="map_func"/>, a function that turns a tensor into a dataset.
/// NOTE(review): presumably the per-element datasets are flattened into one — confirm against the implementation.
/// </summary>
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);

/// <summary>
/// Returns a dataset configured with the given autotune <paramref name="algorithm"/> and <paramref name="cpu_budget"/>.
/// NOTE(review): units and valid range of <paramref name="cpu_budget"/> are not visible here — confirm.
/// </summary>
IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);

/// <summary>
/// Returns a dataset that carries the supplied <paramref name="options"/>.
/// </summary>
IDatasetV2 with_options(DatasetOptions options);

/// <summary>
/// Apply options, such as optimization configuration, to the dataset.
/// </summary>


+ 21
- 0
src/TensorFlowNET.Core/Data/OptionsDataset.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// An identity `Dataset` that stores options.
    /// </summary>
    public class OptionsDataset : UnaryUnchangedStructureDataset
    {
        // Options supplied at construction. They are only stored; nothing in this class reads them back.
        readonly DatasetOptions _options;

        /// <summary>
        /// Creates a dataset that shares the variant tensor of <paramref name="input_dataset"/>
        /// and remembers <paramref name="options"/>.
        /// </summary>
        /// <param name="input_dataset">Dataset being wrapped; also passed to the base constructor.</param>
        /// <param name="options">Options object to keep alongside the dataset.</param>
        public OptionsDataset(IDatasetV2 input_dataset, DatasetOptions options)
            : base(input_dataset)
        {
            // Reuse the input's variant tensor as-is: no new op is created here.
            variant_tensor = input_dataset.variant_tensor;
            _options = options;
        }
    }
}

+ 34
- 0
src/TensorFlowNET.Core/Data/ParallelMapDataset.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
{
    /// <summary>
    /// A `Dataset` that maps a function over elements in its input in parallel.
    /// </summary>
    public class ParallelMapDataset : UnaryDataset
    {
        /// <summary>
        /// Builds the dataset by wrapping <paramref name="map_func"/> in a <c>ConcreteFunction</c>
        /// and creating the variant tensor through <c>ops.parallel_map_dataset_v2</c>.
        /// </summary>
        /// <param name="input_dataset">Dataset whose elements are mapped; also passed to the base constructor.</param>
        /// <param name="map_func">Function to apply.</param>
        /// <param name="num_parallel_calls">
        /// Converted to an int64 tensor named "num_parallel_calls" and passed to the op; defaults to -1
        /// (presumably the "autotune" sentinel — TODO confirm).
        /// </param>
        /// <param name="use_inter_op_parallelism">Accepted but not referenced by this constructor.</param>
        /// <param name="preserve_cardinality">Accepted but not referenced by this constructor.</param>
        /// <param name="use_legacy_function">Accepted but not referenced by this constructor.</param>
        public ParallelMapDataset(IDatasetV2 input_dataset,
            Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
            int num_parallel_calls = -1,
            bool use_inter_op_parallelism = true,
            bool preserve_cardinality = false,
            bool use_legacy_function = false)
            : base(input_dataset)
        {
            // Describe the function's inputs by the dtype and shape of each input element spec.
            var input_dtypes = input_dataset.element_spec.Select(spec => spec.dtype).ToArray();
            var input_shapes = input_dataset.element_spec.Select(spec => spec.shape).ToArray();
            var concrete_func = new ConcreteFunction(map_func, input_dtypes, input_shapes);

            // The structure is assigned before the op call below, which passes
            // output_types / output_shapes (presumably derived from it — TODO confirm).
            structure = concrete_func.OutputStructure;

            var parallel_calls_tensor = tf.convert_to_tensor(
                num_parallel_calls,
                dtype: tf.int64,
                name: "num_parallel_calls");

            variant_tensor = ops.parallel_map_dataset_v2(
                input_dataset.variant_tensor,
                parallel_calls_tensor,
                concrete_func,
                output_types,
                output_shapes);
        }
    }
}

+ 1
- 3
src/TensorFlowNET.Core/Data/TensorDataset.cs View File

@@ -15,11 +15,9 @@ namespace Tensorflow
/// <summary>
/// Creates a dataset from a feature tensor and a label tensor.
/// </summary>
/// <param name="feature">First component of the dataset.</param>
/// <param name="label">Second component of the dataset.</param>
public TensorDataset(Tensor feature, Tensor label)
{
    _tensors = new[] { feature, label };

    // The structure is each tensor's own spec. An earlier assignment built from
    // `_unbatch()`-ed specs was overwritten by this statement before being used,
    // so that dead store and its temporary have been dropped.
    structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();

    variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
}
public TensorDataset(Tensor element)
{


+ 5
- 2
src/TensorFlowNET.Core/Data/ZipDataset.cs View File

@@ -2,16 +2,19 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
public class ZipDataset : DatasetV2
{
dataset_ops ops = new dataset_ops();
/// <summary>
/// Creates a dataset by zipping the given datasets together.
/// </summary>
/// <param name="ds">Datasets to zip; their variant tensors are passed to <c>ops.zip_dataset</c> in order.</param>
public ZipDataset(params IDatasetV2[] ds)
{
    var input_datasets = ds.Select(x => x.variant_tensor).ToArray();

    // Concatenate every component spec of every input, in input order. Taking only
    // `structure[0]` of each input would drop the remaining components of
    // multi-component inputs, and that value was overwritten anyway.
    structure = ds.SelectMany(x => x.structure).ToArray();

    variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes);
}
}


Loading…
Cancel
Save