diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 9a4af07c..c0b51a8b 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -1,11 +1,12 @@ import os import sys import __main__ -from functools import wraps +from functools import wraps, partial from inspect import ismethod from copy import deepcopy from io import StringIO import time +import signal import numpy as np @@ -29,7 +30,15 @@ def recover_logger(fn): return wrapper -def magic_argv_env_context(fn): +def magic_argv_env_context(fn=None, timeout=600): + """ + 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; + :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; + :return: + """ + # 说明是通过 @magic_argv_env_context(timeout=600) 调用; + if fn is None: + return partial(magic_argv_env_context, timeout=timeout) @wraps(fn) def wrapper(*args, **kwargs): @@ -55,11 +64,17 @@ def magic_argv_env_context(fn): else: sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{get_class_that_defined_method(fn).__name__}::{fn.__name__}{subtest}"] + used_args + def _handle_timeout(signum, frame): + raise TimeoutError(f"\nYour test fn: {fn.__name__} has timed out.\n") + + signal.signal(signal.SIGALRM, _handle_timeout) + signal.alarm(timeout) res = fn(*args, **kwargs) + signal.alarm(0) sys.argv = deepcopy(command) os.environ = env - return res + return wrapper