You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

profiler.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import sys
from functools import wraps
from typing import Any, Callable, ContextManager, Dict, Tuple, Type
  # Helper: raise an exception with an explicit traceback attached.
def reraise(tp, value, tb):
    """Raise ``value`` (or a fresh ``tp()``) with traceback ``tb`` attached.

    Args:
        tp: Exception class; instantiated with no arguments when ``value``
            is None.
        value: Exception instance to raise, or None.
        tb: Traceback object to attach to the raised exception (may be None).

    Raises:
        Always raises; the function never returns normally.
    """
    try:
        exc = tp() if value is None else value
        if exc.__traceback__ is tb:
            # Traceback already matches; raise the instance as-is.
            raise exc
        raise exc.with_traceback(tb)
    finally:
        # Drop every local that references the exception or traceback so
        # this frame does not keep them alive (frame <-> traceback cycle).
        exc = value = tb = None
  # Context manager that profiles the enclosed code with cProfile.
class Profiler:
    """Profile the body of a ``with`` block using :mod:`cProfile`.

    On a clean exit, the collected statistics are sorted by ``sort_by``. The
    top ``limit`` entries are then printed to ``stream``, or to
    ``sys.stderr`` when no stream was given. If the block raises, profiling
    stops, nothing is printed, and the exception propagates unchanged.
    """

    def __init__(self,
                 sort_by: str = 'tottime',
                 limit: int = 20,
                 stream: Any = None) -> None:
        """Create a profiler; profiling starts only on ``__enter__``.

        Args:
            sort_by: Sort key passed to ``pstats.Stats.sort_stats``.
            limit: Number of entries passed to ``pstats.Stats.print_stats``.
            stream: Text stream for the report; None means ``sys.stderr``,
                looked up when the report is printed.
        """
        import cProfile
        self.pr = cProfile.Profile()
        self.sort_by = sort_by
        self.limit = limit
        self.stream = stream

    def __enter__(self) -> 'Profiler':
        """Start collecting profile data and return this profiler."""
        self.pr.enable()
        return self

    def __exit__(self, tp, exc, tb) -> bool:
        """Stop profiling and, if no exception occurred, print the report.

        Returns:
            False, so any exception raised in the block propagates.
        """
        self.pr.disable()
        if tp is not None:
            # A falsy return makes the ``with`` statement re-raise the
            # original exception with its original traceback; re-raising
            # manually here would only add extra frames.
            return False
        import pstats
        # Resolve stderr lazily so a redirected sys.stderr is honoured.
        stream = sys.stderr if self.stream is None else self.stream
        stats = pstats.Stats(self.pr, stream=stream).sort_stats(self.sort_by)
        stats.print_stats(self.limit)
        return False
  # Decorator factory: run a function inside a freshly created context manager.
def wrapper(
    tp: Callable[[], ContextManager[Any]],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that runs each call of a function inside ``tp()``.

    A new context manager is created by calling ``tp()`` on every call of the
    decorated function. Exceptions raised by the function pass through the
    context manager's ``__exit__`` as usual.

    Args:
        tp: Zero-argument factory returning a context manager, for example
            the ``Profiler`` class.

    Returns:
        A decorator that takes a function and returns a wrapped function.
        ``functools.wraps`` copies the name, docstring and other metadata.
    """
    def _inner(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def executor(*args: Any, **kwargs: Any) -> Any:
            with tp():
                return func(*args, **kwargs)
        return executor
    return _inner
  # Target of the monkey-patch performed by ``enable`` below.
# Name of the class, inside PIPELINE_BASE_MODULE, whose ``__call__`` is wrapped.
PIPELINE_BASE_CLASS = 'Pipeline'
# Dotted import path of the module that defines that class.
PIPELINE_BASE_MODULE = 'modelscope.pipelines.base'
  # Public entry point: profile every call of the pipeline base class.
def enable():
    """Monkey-patch the pipeline base class so each call is profiled.

    Imports ``PIPELINE_BASE_MODULE`` and looks up ``PIPELINE_BASE_CLASS`` in
    it. The class's ``__call__`` is then replaced with a version wrapped by
    ``wrapper(Profiler)``, so each call runs under a new ``Profiler``.

    Calling ``enable`` more than once is a no-op after the first call.
    Without this guard, every extra call would wrap ``__call__`` again,
    nesting profilers and printing the statistics once per layer. Some
    newer CPython versions may also refuse to enable a second cProfile
    profiler while one is active — NOTE(review): confirm for targeted
    versions.

    Raises:
        ImportError: if ``PIPELINE_BASE_MODULE`` cannot be imported.
        AttributeError: if the module has no ``PIPELINE_BASE_CLASS``.
    """
    marker = '_profiler_enabled'
    base = importlib.import_module(PIPELINE_BASE_MODULE)
    pipeline_cls = getattr(base, PIPELINE_BASE_CLASS)
    original_call = pipeline_cls.__call__
    if getattr(original_call, marker, False):
        # Already patched by an earlier enable(); do not wrap twice.
        return
    profiled_call = wrapper(Profiler)(original_call)
    setattr(profiled_call, marker, True)
    pipeline_cls.__call__ = profiled_call