using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Data;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;
namespace Tensorflow
{
///
/// Abstract class representing a dataset with no inputs.
///
public class DatasetV2 : IDatasetV2
{
protected dataset_ops ops = new dataset_ops();
public string[] class_names { get; set; }
public Tensor variant_tensor { get; set; }
public TensorSpec[] structure { get; set; }
public int FirstInputTensorCount { get; set; } = 1;
public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();
public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
public TensorSpec[] element_spec => structure;
public int length => cardinality().numpy();
public IDatasetV2 cache(string filename = "")
=> new CacheDataset(this, filename: filename);
public IDatasetV2 concatenate(IDatasetV2 dataset)
=> new ConcatenateDataset(this, dataset);
public IDatasetV2 take(int count = -1)
=> new TakeDataset(this, count: count);
public IDatasetV2 batch(int batch_size, bool drop_remainder = false)
=> new BatchDataset(this, batch_size, drop_remainder: drop_remainder);
public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
=> new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period);
public IDatasetV2 repeat(int count = -1)
=> new RepeatDataset(this, count: count);
public IDatasetV2 shard(int num_shards, int index)
=> new ShardDataset(this, num_shards, index);
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
public IDatasetV2 skip(int count)
=> new SkipDataset(this, count);
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
public IDatasetV2 map(Func map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = true,
bool use_legacy_function = false)
=> new MapDataset(this,
map_func,
use_inter_op_parallelism: use_inter_op_parallelism,
preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function);
public IDatasetV2 map(Func map_func, int num_parallel_calls)
=> new ParallelMapDataset(this, map_func,
num_parallel_calls: num_parallel_calls,
preserve_cardinality: true);
public IDatasetV2 filter(Func predicate_func)
=> new FilterDataset(this, predicate_func);
public IDatasetV2 filter(Func predicate_func)
=> new FilterDataset(this, predicate_func);
public OwnedIterator make_one_shot_iterator()
{
if (tf.Context.executing_eagerly())
{
// with ops.colocate_with(self._variant_tensor)
return new OwnedIterator(this);
}
throw new NotImplementedException("");
}
public IDatasetV2 flat_map(Func map_func)
=> new FlatMapDataset(this, map_func);
public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget, long ram_budget)
=> new ModelDataset(this, algorithm, cpu_budget, ram_budget);
public IDatasetV2 with_options(DatasetOptions options)
=> new OptionsDataset(this, options);
public IDatasetV2 apply_options()
{
IDatasetV2 dataset = this;
// (1) Apply threading options
// (2) Apply autotune options
var autotune = true;
long cpu_budget = 0;
long ram_budget = 0;
if (autotune)
dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget, ram_budget);
// (3) Apply graph rewrite options
var graph_rewrites = new[]
{
"map_and_batch_fusion",
"map_parallelization",
"noop_elimination",
"shuffle_and_repeat_fusion"
};
var graph_rewrite_configs = new string[]
{
"autotune_buffer_sizes:autotune:true",
"batch_parallelization:autotune:true",
"disable_prefetch_legacy_autotune:autotune:true",
"enable_gradient_descent:autotune:true",
"map_parallelization:autotune:true"
};
dataset = new OptimizeDataset(dataset, new string[0], new string[0], graph_rewrites, graph_rewrite_configs);
// (4) Apply stats aggregator options
dataset.FirstInputTensorCount = this.FirstInputTensorCount;
return dataset;
}
public Tensor cardinality(string name = null)
=> tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));
public override string ToString()
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, " +
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
$"len: {length}";
public IEnumerator<(Tensors, Tensors)> GetEnumerator()
{
using var ownedIterator = new OwnedIterator(this);
Tensor[] results = null;
while (true)
{
try
{
results = ownedIterator.next();
}
catch (StopIteration)
{
break;
}
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return this.GetEnumerator();
}
}
}