You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

mindspore_data.py 2.7 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import mindspore.dataset as ds
  4. import mindspore as ms
  5. from enum import Enum
  6. __all__ = [
  7. 'Batch',
  8. 'Concat',
  9. 'FromGenerator',
  10. 'FromSlices',
  11. 'Map',
  12. 'Repeat',
  13. 'Shuffle',
  14. 'Dataloader',
  15. 'Dataset',
  16. 'IterableDataset',
  17. ]
  18. class Dataset(object):
  19. def __init__(self):
  20. pass
  21. def __getitem__(self, idx):
  22. raise NotImplementedError("'{}' not implement in class "\
  23. "{}".format('__getitem__', self.__class__.__name__))
  24. def __len__(self):
  25. raise NotImplementedError("'{}' not implement in class "\
  26. "{}".format('__len__', self.__class__.__name__))
  27. class IterableDataset(object):
  28. def __init__(self):
  29. pass
  30. def __iter__(self):
  31. raise NotImplementedError("'{}' not implement in class " \
  32. "{}".format('__iter__', self.__class__.__name__))
  33. def Batch(dataset, batch_size, drop_last=False):
  34. '''
  35. Parameters
  36. ----------
  37. dataset
  38. batch_size
  39. drop_last
  40. Returns
  41. -------
  42. '''
  43. return dataset.batch(batch_size=batch_size, drop_remainder=drop_last)
  44. def Concat(datasets):
  45. datasets = list(datasets)
  46. dataset = ds.Dataset.concat(datasets)
  47. return dataset
  48. def FromGenerator(generator, output_types, column_names):
  49. output_types = list(output_types)
  50. column_names = list(column_names)
  51. return ds.GeneratorDataset(source=generator, column_names=column_names, column_types=output_types)
  52. def FromSlices(datas, column_names):
  53. return ds.NumpySlicesDataset(data=datas, column_names=column_names)
  54. def Map(dataset, map_func, input_columns=None):
  55. """ Maps map_func across the elements of this dataset.
  56. Parameters
  57. ----------
  58. dataset : DataFlow
  59. input DataFlow
  60. map_func : function
  61. A function mapping a dataset element to another dataset element.
  62. num_parallel_calls
  63. Returns
  64. -------
  65. """
  66. return dataset.map(operations=map_func, input_columns=input_columns)
  67. def Repeat(dataset, count=None):
  68. return dataset.repeat(count)
  69. def Shuffle(dataset, buffer_size):
  70. return dataset.shuffle(buffer_size)
  71. def Zip(datasets):
  72. '''
  73. Creates a Dataset by zipping together the given datasets.
  74. Parameters
  75. ----------
  76. datasets:
  77. A tuple of datasets to be zipped together.
  78. Returns
  79. -------
  80. '''
  81. datasets = tuple(datasets)
  82. return ds.zip(datasets)
  83. def Dataloader(dataset, batch_size, shuffle=False, drop_last=False, shuffle_buffer_size=10000):
  84. if shuffle:
  85. dataset = Shuffle(dataset, buffer_size=shuffle_buffer_size)
  86. dataset = Batch(dataset, batch_size=batch_size, drop_last=drop_last)
  87. return dataset

TensorLayer3.0 是一款兼容多种深度学习框架为计算后端的深度学习库。计划兼容TensorFlow, Pytorch, MindSpore, Paddle.