|
|
@@ -1,11 +1,12 @@ |
|
|
|
import os |
|
|
|
import sys |
|
|
|
import __main__ |
|
|
|
from functools import wraps |
|
|
|
from functools import wraps, partial |
|
|
|
from inspect import ismethod |
|
|
|
from copy import deepcopy |
|
|
|
from io import StringIO |
|
|
|
import time |
|
|
|
import signal |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
@@ -29,7 +30,15 @@ def recover_logger(fn): |
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
def magic_argv_env_context(fn): |
|
|
|
def magic_argv_env_context(fn=None, timeout=600): |
|
|
|
""" |
|
|
|
用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; |
|
|
|
:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; |
|
|
|
:return: |
|
|
|
""" |
|
|
|
# 说明是通过 @magic_argv_env_context(timeout=600) 调用; |
|
|
|
if fn is None: |
|
|
|
return partial(magic_argv_env_context, timeout=timeout) |
|
|
|
|
|
|
|
@wraps(fn) |
|
|
|
def wrapper(*args, **kwargs): |
|
|
@@ -55,11 +64,17 @@ def magic_argv_env_context(fn): |
|
|
|
else: |
|
|
|
sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{get_class_that_defined_method(fn).__name__}::{fn.__name__}{subtest}"] + used_args |
|
|
|
|
|
|
|
def _handle_timeout(signum, frame): |
|
|
|
raise TimeoutError(f"\nYour test fn: {fn.__name__} has timed out.\n") |
|
|
|
|
|
|
|
signal.signal(signal.SIGALRM, _handle_timeout) |
|
|
|
signal.alarm(timeout) |
|
|
|
res = fn(*args, **kwargs) |
|
|
|
signal.alarm(0) |
|
|
|
sys.argv = deepcopy(command) |
|
|
|
os.environ = env |
|
|
|
|
|
|
|
return res |
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|