mindspore_data.py
#! /usr/bin/python
# -*- coding: utf-8 -*-

import mindspore.dataset as ds
import mindspore as ms
from enum import Enum

__all__ = [
    'Apply',
    'Batch',
    'Concat',
    'CsvDataset',
    'Filter',
    'Flat_map',
    'FromGenerator',
    'FromSlices',
    'Map',
    'Prefetch',
    'Repeat',
    'Shuffle',
    'Skip',
    'Take',
    'TextFlieDataset',
    'TFRecordDataset',
    'Dataloader',
]


class Shuffle(str, Enum):
    GLOBAL: str = "global"
    FILES: str = "file"


def Apply(dataset, transformation_func):
    return dataset.apply(transformation_func)


def Batch(
    dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None,
    output_columns=None, column_order=None, pad_info=None
):
    '''
    Combine batch_size number of consecutive rows into batches.

    Parameters
    ----------
    dataset
        Input dataset.
    batch_size
        Number of consecutive rows combined into one batch.
    drop_remainder
        Whether to drop the last incomplete batch.
    num_parallel_workers
        Number of workers used to process batches in parallel.
    per_batch_map
        Callable applied to each batch before it is returned.
    input_columns
        Columns passed to per_batch_map.
    output_columns
        Columns produced by per_batch_map.
    column_order
        Order of the columns in the output dataset.
    pad_info
        Padding information for the batch.

    Returns
    -------
    The batched dataset.
    '''
    return dataset.batch(
        batch_size=batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_parallel_workers,
        per_batch_map=per_batch_map, input_columns=input_columns, output_columns=output_columns,
        column_order=column_order, pad_info=pad_info
    )
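

# Illustrative usage of Batch (a minimal sketch, not part of the original module;
# the sample data and the 'data' column name are made up for the example):
#
# >>> import numpy as np
# >>> d = FromSlices(np.arange(12, dtype=np.int32), column_names=['data'])
# >>> d = Batch(d, batch_size=4, drop_remainder=True)
# >>> for row in d.create_tuple_iterator():
# ...     print(row[0].shape)   # each batch holds 4 consecutive rows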


def Concat(dataset_1, dataset_2):
    return dataset_1.concat(dataset_2)


def CsvDataset(
    file_pattern, batch_size=1, column_names=None, column_defaults=None, label_name=None, select_columns=None,
    field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=Shuffle.GLOBAL,
    shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
    num_rows_for_inference=100, compression_type=None, ignore_errors=False, num_samples=None, num_shards=None,
    shard_id=None, cache=None
):
    """
    A source dataset that reads and parses comma-separated values (CSV) files.

    Examples:
        >>> import mindspore.dataset as dataset
        >>>
        >>> dataset_files = ["/path/to/1", "/path/to/2"]  # contains one or multiple CSV files
        >>> dataset = dataset.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])
    """
    return ds.CSVDataset(
        dataset_files=file_pattern, field_delim=field_delim, column_defaults=column_defaults, column_names=column_names,
        num_samples=num_samples, num_parallel_workers=num_parallel_reads, shuffle=shuffle, num_shards=num_shards,
        shard_id=shard_id, cache=cache
    )


def Filter(dataset, predicate):
    return dataset.filter(predicate)


def Flat_map(dataset, map_func):
    return dataset.flat_map(map_func)


def FromGenerator(
    generator, output_types, output_shapes=None, args=None, column_names=None, column_types=None, schema=None,
    num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
    python_multiprocessing=True
):
    # output_types, output_shapes and args are kept for a TensorFlow-style signature but are not
    # forwarded to the MindSpore backend.
    return ds.GeneratorDataset(
        source=generator, column_names=column_names, column_types=column_types, schema=schema, num_samples=num_samples,
        num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_shards=num_shards,
        shard_id=shard_id, python_multiprocessing=python_multiprocessing
    )
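

# Illustrative usage of FromGenerator (a minimal sketch, not part of the original module;
# the generator and the 'col1' column name are made up for the example):
#
# >>> import numpy as np
# >>> def gen():
# ...     for i in range(5):
# ...         yield (np.array([i], dtype=np.int32),)
# >>> d = FromGenerator(gen, output_types=None, column_names=['col1'])
# >>> d = Repeat(d, count=2)   # iterate the 5 generated rows twice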


def FromSlices(
    tensor, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None,
    shard_id=None
):
    return ds.NumpySlicesDataset(
        data=tensor, column_names=column_names, num_samples=num_samples, num_parallel_workers=num_parallel_workers,
        shuffle=shuffle, sampler=sampler, num_shards=num_shards, shard_id=shard_id
    )


def Map(
    dataset, map_func, num_parallel_calls=None, input_columns=None, output_columns=None, column_order=None,
    num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None
):
    """ Maps map_func across the elements of this dataset.

    Parameters
    ----------
    dataset : Dataset
        Input dataset.
    map_func : function
        A function mapping a dataset element to another dataset element.
    num_parallel_calls
        Kept for API compatibility; use num_parallel_workers to control parallelism on the MindSpore backend.

    Returns
    -------
    The mapped dataset.
    """
    return dataset.map(
        operations=map_func, input_columns=input_columns, output_columns=output_columns, column_order=column_order,
        num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing, cache=cache,
        callbacks=callbacks
    )
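

# Illustrative usage of Map (a minimal sketch, not part of the original module;
# the sample data and the 'x' column name are made up for the example):
#
# >>> import numpy as np
# >>> d = FromSlices(np.arange(5, dtype=np.int32), column_names=['x'])
# >>> d = Map(d, map_func=lambda x: x * 2, input_columns=['x'])
# >>> [row[0] for row in d.create_tuple_iterator(output_numpy=True)]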


def Prefetch(dataset, buffer_size):
    # Prefetch size is a global pipeline setting in MindSpore, so it is configured through
    # ds.config rather than on the dataset object; the dataset itself is returned unchanged.
    if buffer_size > 0:
        batch_size = dataset.get_batch_size()
        prefetch_size = batch_size * buffer_size
        ds.config.set_prefetch_size(prefetch_size)
    return dataset


def Repeat(dataset, count=None):
    return dataset.repeat(count)


def Shuffle(dataset, buffer_size, seed=None, reshuffle_each_iteration=None):
    # seed and reshuffle_each_iteration are kept for API compatibility; MindSpore shuffles with a
    # row buffer of size buffer_size, which must be at least 2.
    # ds.config.set_seed(seed)
    return dataset.shuffle(buffer_size)


def Skip(dataset, count):
    '''
    Creates a dataset that skips count elements of this dataset.

    Parameters
    ----------
    dataset:
        A dataset.
    count:
        An integer, the number of elements of this dataset that should be skipped to form the new dataset.

    Returns
    -------
    The dataset with the first count elements skipped.
    '''
    return dataset.skip(count)


def Take(dataset, count):
    '''
    Creates a dataset with at most count elements from this dataset.

    Parameters
    ----------
    dataset:
        A dataset.
    count:
        An integer, the number of elements of this dataset that should be taken to form the new dataset.
        If count is -1, or if count is greater than the size of this dataset, the new dataset will
        contain all elements of this dataset.

    Returns
    -------
    The truncated dataset.
    '''
    return dataset.take(count)
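

# Illustrative usage of Skip and Take (a minimal sketch, not part of the original module;
# the sample data and the 'x' column name are made up for the example):
#
# >>> import numpy as np
# >>> d = FromSlices(np.arange(10, dtype=np.int32), column_names=['x'])
# >>> d = Skip(d, 2)   # drop the first 2 rows
# >>> d = Take(d, 5)   # keep at most 5 of the remaining rows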


def TextFlieDataset(
    filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, num_samples=None, shuffle=None,
    num_shards=None, shard_id=None, cache=None
):
    """
    A source dataset that reads and parses text files stored on disk.
    The generated dataset has one column ['text'].

    Examples:
        >>> import mindspore.dataset as dataset
        >>>
        >>> dataset_files = ["/path/to/1", "/path/to/2"]  # contains one or multiple text files
        >>> dataset = dataset.TextFileDataset(dataset_files=dataset_files)
    """
    if shuffle is None:
        # Use MindSpore's own Shuffle enum here: the local Shuffle enum defined above is shadowed
        # by the Shuffle() wrapper function at this point in the module.
        shuffle = ds.Shuffle.GLOBAL
    return ds.TextFileDataset(
        dataset_files=filenames, num_samples=num_samples, num_parallel_workers=num_parallel_reads, shuffle=shuffle,
        num_shards=num_shards, shard_id=shard_id, cache=cache
    )


def TFRecordDataset(
    filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, schema=None, columns_list=None,
    num_samples=None, shuffle=None, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None
):
    """
    A source dataset that reads and parses datasets stored on disk in TFData format.

    Examples:
        >>> import mindspore.dataset as dataset
        >>> import mindspore.common.dtype as mstype
        >>>
        >>> dataset_files = ["/path/to/1", "/path/to/2"]  # contains one or multiple TFRecord files
        >>>
        >>> # 1) Get all rows from dataset_files with no explicit schema.
        >>> # The meta-data in the first row will be used as a schema.
        >>> tfdataset = dataset.TFRecordDataset(dataset_files=dataset_files)
        >>>
        >>> # 2) Get all rows from dataset_files with a user-defined schema.
        >>> schema = dataset.Schema()
        >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
        >>> tfdataset = dataset.TFRecordDataset(dataset_files=dataset_files, schema=schema)
        >>>
        >>> # 3) Get all rows from dataset_files with the schema file "./schema.json".
        >>> tfdataset = dataset.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
    """
    if shuffle is None:
        # Use MindSpore's own Shuffle enum here: the local Shuffle enum defined above is shadowed
        # by the Shuffle() wrapper function at this point in the module.
        shuffle = ds.Shuffle.GLOBAL
    return ds.TFRecordDataset(
        dataset_files=filenames, schema=schema, columns_list=columns_list, num_samples=num_samples,
        num_parallel_workers=num_parallel_reads, shuffle=shuffle, num_shards=num_shards, shard_id=shard_id,
        shard_equal_rows=shard_equal_rows, cache=cache
    )


def Zip(datasets):
    '''
    Creates a dataset by zipping together the given datasets.

    Parameters
    ----------
    datasets:
        A tuple of datasets to be zipped together.

    Returns
    -------
    The zipped dataset.
    '''
    return ds.zip(datasets)
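

# Illustrative usage of Zip (a minimal sketch, not part of the original module;
# the sample data and the 'a' / 'b' column names are made up for the example):
#
# >>> import numpy as np
# >>> d1 = FromSlices(np.arange(3, dtype=np.int32), column_names=['a'])
# >>> d2 = FromSlices(np.arange(3, dtype=np.int32), column_names=['b'])
# >>> d = Zip((d1, d2))   # rows now carry both columns 'a' and 'b'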


def Dataloader(dataset, batch_size, shuffle=False, drop_last=False, prefetch=0, shuffle_buffer_size=0):
    # Convenience wrapper: optionally shuffle, then batch, then configure prefetching.
    # Pass a suitable shuffle_buffer_size when shuffle is True.
    if shuffle:
        dataset = Shuffle(dataset, buffer_size=shuffle_buffer_size)
    dataset = Batch(dataset, batch_size=batch_size, drop_remainder=drop_last)
    dataset = Prefetch(dataset, buffer_size=prefetch)
    return dataset
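

# Illustrative end-to-end pipeline using the wrappers above (a minimal sketch, not part of the
# original module; the sample data and the 'x' column name are made up for the example):
#
# >>> import numpy as np
# >>> d = FromSlices(np.arange(8, dtype=np.int32), column_names=['x'])
# >>> d = Dataloader(d, batch_size=2, shuffle=True, shuffle_buffer_size=4)
# >>> for batch in d.create_tuple_iterator(output_numpy=True):
# ...     print(batch[0].shape)   # (2,)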

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as its computational backends. Planned backends include TensorFlow, PyTorch, MindSpore, and Paddle.