You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetV2.cs 5.3 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Data;
  6. using Tensorflow.Framework.Models;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow
  9. {
  10. /// <summary>
  11. /// Abstract class representing a dataset with no inputs.
  12. /// </summary>
  13. public class DatasetV2 : IDatasetV2
  14. {
  15. protected dataset_ops ops = new dataset_ops();
  16. public string[] class_names { get; set; }
  17. public Tensor variant_tensor { get; set; }
  18. public TensorSpec[] structure { get; set; }
  19. public TensorShape[] output_shapes => structure.Select(x => x.shape).ToArray();
  20. public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
  21. public TensorSpec[] element_spec => structure;
  22. public IDatasetV2 cache(string filename = "")
  23. => new CacheDataset(this, filename: filename);
  24. public IDatasetV2 concatenate(IDatasetV2 dataset)
  25. => new ConcatenateDataset(this, dataset);
  26. public IDatasetV2 take(int count = -1)
  27. => new TakeDataset(this, count: count);
  28. public IDatasetV2 batch(int batch_size, bool drop_remainder = false)
  29. => new BatchDataset(this, batch_size, drop_remainder: drop_remainder);
  30. public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
  31. => new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period);
  32. public IDatasetV2 repeat(int count = -1)
  33. => new RepeatDataset(this, count: count);
  34. public IDatasetV2 shard(int num_shards, int index)
  35. => new ShardDataset(this, num_shards, index);
  36. public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
  37. => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
  38. public IDatasetV2 skip(int count)
  39. => new SkipDataset(this, count);
  40. public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
  41. => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
  42. public IDatasetV2 map(Func<Tensors, Tensors> map_func,
  43. bool use_inter_op_parallelism = true,
  44. bool preserve_cardinality = true,
  45. bool use_legacy_function = false)
  46. => new MapDataset(this,
  47. map_func,
  48. use_inter_op_parallelism: use_inter_op_parallelism,
  49. preserve_cardinality: preserve_cardinality,
  50. use_legacy_function: use_legacy_function);
  51. public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
  52. => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
  53. public OwnedIterator make_one_shot_iterator()
  54. {
  55. if (tf.Context.executing_eagerly())
  56. {
  57. // with ops.colocate_with(self._variant_tensor)
  58. return new OwnedIterator(this);
  59. }
  60. throw new NotImplementedException("");
  61. }
  62. public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
  63. => new FlatMapDataset(this, map_func);
  64. public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
  65. => new ModelDataset(this, algorithm, cpu_budget);
  66. public IDatasetV2 with_options(DatasetOptions options)
  67. => new OptionsDataset(this, options);
  68. public IDatasetV2 apply_options()
  69. {
  70. // (1) Apply threading options
  71. var graph_rewrites = new[]
  72. {
  73. "map_and_batch_fusion",
  74. "noop_elimination",
  75. "shuffle_and_repeat_fusion"
  76. };
  77. var graph_rewrite_configs = new string[0];
  78. // (2) Apply graph rewrite options
  79. var dataset = optimize(graph_rewrites, graph_rewrite_configs);
  80. // (3) Apply autotune options
  81. var autotune = true;
  82. long cpu_budget = 0;
  83. if (autotune)
  84. dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget);
  85. // (4) Apply stats aggregator options
  86. return dataset;
  87. }
  88. public Tensor dataset_cardinality(string name = null)
  89. => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));
  90. public override string ToString()
  91. => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";
  92. public IEnumerator<(Tensor, Tensor)> GetEnumerator()
  93. {
  94. using var ownedIterator = new OwnedIterator(this);
  95. Tensor[] results = null;
  96. while (true)
  97. {
  98. try
  99. {
  100. results = ownedIterator.next();
  101. }
  102. catch (StopIteration)
  103. {
  104. break;
  105. }
  106. yield return (results[0], results.Length == 1 ? null : results[1]);
  107. }
  108. }
  109. IEnumerator IEnumerable.GetEnumerator()
  110. {
  111. return this.GetEnumerator();
  112. }
  113. }
  114. }