You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from collections.abc import Sequence
  4. from .builder import PREPROCESSORS, build_preprocessor
  5. @PREPROCESSORS.register_module()
  6. class Compose(object):
  7. """Compose a data pipeline with a sequence of transforms.
  8. Args:
  9. transforms (list[dict | callable]):
  10. Either config dicts of transforms or transform objects.
  11. profiling (bool, optional): If set True, will profile and
  12. print preprocess time for each step.
  13. """
  14. def __init__(self, transforms, field_name=None, profiling=False):
  15. assert isinstance(transforms, Sequence)
  16. self.profiling = profiling
  17. self.transforms = []
  18. self.field_name = field_name
  19. for transform in transforms:
  20. if isinstance(transform, dict):
  21. if self.field_name is None:
  22. transform = build_preprocessor(transform, field_name)
  23. self.transforms.append(transform)
  24. elif callable(transform):
  25. self.transforms.append(transform)
  26. else:
  27. raise TypeError('transform must be callable or a dict, but got'
  28. f' {type(transform)}')
  29. def __call__(self, data):
  30. for t in self.transforms:
  31. if self.profiling:
  32. start = time.time()
  33. data = t(data)
  34. if self.profiling:
  35. print(f'{t} time {time.time()-start}')
  36. if data is None:
  37. return None
  38. return data
  39. def __repr__(self):
  40. format_string = self.__class__.__name__ + '('
  41. for t in self.transforms:
  42. format_string += f'\n {t}'
  43. format_string += '\n)'
  44. return format_string

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展