You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_data.py 2.4 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import numpy as np
  4. import paddle
  5. from paddle.io import Dataset as dataset
  6. from paddle.io import IterableDataset as iterabledataset
  7. from paddle.io import DataLoader
  8. __all__ = [
  9. 'Batch',
  10. 'Concat',
  11. 'FromGenerator',
  12. 'FromSlices',
  13. 'Map',
  14. 'Repeat',
  15. 'Shuffle',
  16. 'Dataloader',
  17. 'Dataset',
  18. 'IterableDataset',
  19. ]
  20. class Dataset(dataset):
  21. def __init__(self):
  22. pass
  23. def __getitem__(self, idx):
  24. raise NotImplementedError("'{}' not implement in class "\
  25. "{}".format('__getitem__', self.__class__.__name__))
  26. def __len__(self):
  27. raise NotImplementedError("'{}' not implement in class "\
  28. "{}".format('__len__', self.__class__.__name__))
  29. class IterableDataset(iterabledataset):
  30. def __init__(self):
  31. pass
  32. def __iter__(self):
  33. raise NotImplementedError("'{}' not implement in class "\
  34. "{}".format('__iter__', self.__class__.__name__))
  35. def __getitem__(self, idx):
  36. raise RuntimeError("'{}' should not be called for IterableDataset" \
  37. "{}".format('__getitem__', self.__class__.__name__))
  38. def __len__(self):
  39. raise RuntimeError("'{}' should not be called for IterableDataset" \
  40. "{}".format('__len__', self.__class__.__name__))
  41. def FromGenerator(generator, output_types=None, column_names=None):
  42. return generator
  43. def FromSlices(datas, column_names=None):
  44. datas = list(datas)
  45. return paddle.io.TensorDataset(datas)
  46. def Concat(datasets):
  47. return paddle.io.ChainDataset(list(datasets))
  48. def Zip(datasets):
  49. return paddle.io.ComposeDataset(list(datasets))
  50. def Dataloader(dataset, batch_size=None, shuffle=False, drop_last=False, shuffle_buffer_size=0):
  51. return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, return_list=True)
  52. def Batch(dataset, batch_size, drop_last=False):
  53. raise NotImplementedError('This function not implement in paddle backend.')
  54. def Shuffle(dataset, buffer_size, seed=None):
  55. raise NotImplementedError('This function not implement in paddle backend.')
  56. def Repeat(dataset, count=None):
  57. raise NotImplementedError('This function not implement in paddle backend.')
  58. def Map(dataset, map_func, input_columns=None):
  59. raise NotImplementedError('This function not implement in paddle backend.')

TensorLayer 3.0 是一款可将多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。