From b1ba42e8315e0dfbd0cb3440e011570ae3e8a801 Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 15 May 2022 22:33:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=8B=A5=E5=B9=B2=E5=BE=AE?= =?UTF-8?q?=E5=B0=8F=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/__init__.py | 4 +- .../callbacks/load_best_model_callback.py | 6 +- fastNLP/core/callbacks/progress_callback.py | 1 - fastNLP/core/controllers/trainer.py | 6 +- fastNLP/core/drivers/paddle_driver/fleet.py | 1 + fastNLP/core/drivers/torch_driver/ddp.py | 1 + fastNLP/core/log/logger.py | 16 ++ fastNLP/core/metrics/accuracy.py | 5 +- fastNLP/core/utils/tqdm_progress.py | 2 +- .../test_checkpoint_callback_torch.py | 145 +++++++++--------- .../test_load_best_model_callback_torch.py | 32 ++-- .../callbacks/test_more_evaluate_callback.py | 12 +- .../callbacks/test_progress_callback_torch.py | 4 +- tests/helpers/utils.py | 11 ++ 14 files changed, 146 insertions(+), 100 deletions(-) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 915e382d..9885a175 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -1,3 +1,5 @@ from fastNLP.envs import * -from fastNLP.core import * \ No newline at end of file +from fastNLP.core import * + +__version__ = '0.8.0beta' diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 1399722a..48bea6e3 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -11,6 +11,7 @@ import shutil from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH from fastNLP.core.log import logger from fastNLP.envs import all_rank_call_context +from fastNLP.core.utils.exceptions import EarlyStopException class LoadBestModelCallback(HasMonitorCallback): @@ -61,7 +62,7 @@ class LoadBestModelCallback(HasMonitorCallback): save_folder = os.path.join(save_folder, os.environ.get(FASTNLP_LAUNCH_TIME)) self.real_save_folder = os.path.join(save_folder, 'best_so_far') if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - os.makedirs(self.real_save_folder) + os.makedirs(self.real_save_folder, exist_ok=True) else: # 创建出一个 stringio self.real_save_folder = None self.buffer = BytesIO() @@ -114,7 +115,8 @@ class LoadBestModelCallback(HasMonitorCallback): trainer.driver.barrier() def on_exception(self, trainer, exception): - self.encounter_exception = True + if not isinstance(exception, EarlyStopException): + self.encounter_exception = True def _delete_folder(self): if self.real_save_folder: diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 9fab4dbd..a091a35c 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -9,7 +9,6 @@ __all__ = [ 'TqdmCallback' ] -from ...envs.imports import _module_available, _compare_version from .has_monitor_callback import HasMonitorCallback from fastNLP.core.utils import f_rich_progress, f_tqdm_progress diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 01be134d..1ff00287 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -295,9 +295,9 @@ class Trainer(TrainerEventTrigger): log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; 注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; - * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 RichCallback, RawTextCallback等对象, - 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback 对象。如果 - 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 等对象。 + * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 :class:`~.fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`等对象, + 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 + 需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。 * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 86a9c3f0..72dbd07d 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -296,6 +296,7 @@ class PaddleFleetDriver(PaddleDriver): 使用 FleetLauncher 拉起子进程 """ if self.local_rank == 0: + logger._set_distributed() # 是 rank0 的话,则拉起其它子进程 launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) launcher.launch() diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 5f3f3108..9dbea342 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -425,6 +425,7 @@ class TorchDDPDriver(TorchDriver): os.environ[FASTNLP_DISTRIBUTED_CHECK] = f"{len(self.parallel_device)}" os.environ[FASTNLP_GLOBAL_RANK] = "0" + logger._set_distributed() interactive_ddp_procs = [] diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 86d52041..251ef3b9 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -208,6 +208,20 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): for handler in self.handlers: handler.setLevel(level) + def _set_distributed(self): + """ + 在 fastNLP 拉起进程的时候,调用一下这个方法,使得能够输出 rank 信息 + + :return: + """ + for handler in self.handlers: + if isinstance(handler, logging.FileHandler): + formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + else: + formatter = logging.Formatter('Rank: %(rank)s - %(message)s') + handler.setFormatter(formatter) + def _get_level(level): if not isinstance(level, int): @@ -350,6 +364,8 @@ def _init_logger(path=None, stdout='rich', level='INFO'): if path is not None: _add_file_handler(logger, path, level) + logger.setLevel(level) + return logger diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index 59990f95..9fa2152b 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -31,9 +31,10 @@ class Accuracy(Metric): r""" get_metric 函数将根据 update 函数累计的评价指标统计量来计算最终的评价结果. - :return dict evaluate_result: {"acc": float} + :return dict evaluate_result: {"acc": float, 'total': float, 'correct': float} """ - evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)} + evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6), + 'total': self.total.item(), 'correct': self.correct.item()} return evaluate_result def update(self, pred, target, seq_len=None): diff --git a/fastNLP/core/utils/tqdm_progress.py b/fastNLP/core/utils/tqdm_progress.py index 9fcfac94..897c4b7d 100644 --- a/fastNLP/core/utils/tqdm_progress.py +++ b/fastNLP/core/utils/tqdm_progress.py @@ -83,7 +83,7 @@ class TqdmProgress(metaclass=Singleton): :return: """ assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \ - f"To use {self.__class__.__name__}, tqdm>=4.57 is needed." + f"To use tqdm, tqdm>=4.57 is needed." from .rich_progress import f_rich_progress assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop." diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index eff7b420..0a99db6a 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -28,10 +28,10 @@ if _NEED_IMPORT_TORCH: class ArgMaxDatasetConfig: num_labels: int = 10 feature_dimension: int = 10 - data_num: int = 100 + data_num: int = 50 seed: int = 0 - batch_size: int = 4 + batch_size: int = 2 shuffle: bool = True @@ -204,99 +204,100 @@ def test_model_checkpoint_callback_1( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("only_state_dict", [True, False]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context(timeout=100) def test_model_checkpoint_callback_2( model_and_optimizers: TrainerParameters, driver, - device + device, + only_state_dict ): - for only_state_dict in [True, False]: - try: - path = Path.cwd().joinpath("test_model_checkpoint") - path.mkdir(exist_ok=True, parents=True) - - from fastNLP.core.callbacks.callback_event import Event + try: + path = Path.cwd().joinpath("test_model_checkpoint") + path.mkdir(exist_ok=True, parents=True) - @Trainer.on(Event.on_train_epoch_end()) - def raise_exception(trainer): - if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: - raise NotImplementedError + from fastNLP.core.callbacks.callback_event import Event - callbacks = [ - CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False, - on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict, - save_object='model'), - ] + @Trainer.on(Event.on_train_epoch_end()) + def raise_exception(trainer): + if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: + raise NotImplementedError - with pytest.raises(NotImplementedError): - trainer = Trainer( - model=model_and_optimizers.model, - driver=driver, - device=device, - optimizers=model_and_optimizers.optimizers, - train_dataloader=model_and_optimizers.train_dataloader, - evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, - input_mapping=model_and_optimizers.input_mapping, - output_mapping=model_and_optimizers.output_mapping, - metrics=model_and_optimizers.metrics, - - n_epochs=10, - callbacks=callbacks, - output_from_new_proc="all" - ) + callbacks = [ + CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False, + on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict, + save_object='model'), + ] - trainer.run() + with pytest.raises(NotImplementedError): + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=10, + callbacks=callbacks, + output_from_new_proc="all" + ) - if dist.is_initialized(): - dist.destroy_process_group() - if FASTNLP_DISTRIBUTED_CHECK in os.environ: - os.environ.pop(FASTNLP_DISTRIBUTED_CHECK) + trainer.run() - # 检查生成保存模型文件的数量是不是正确的; - all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} + if dist.is_initialized(): + dist.destroy_process_group() + if FASTNLP_DISTRIBUTED_CHECK in os.environ: + os.environ.pop(FASTNLP_DISTRIBUTED_CHECK) - if not isinstance(device, list): - assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths - exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] - # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; - else: - assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths - exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"] + # 检查生成保存模型文件的数量是不是正确的; + all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} - assert len(all_saved_model_paths) == 1 - all_state_dicts = [exception_model_path] + if not isinstance(device, list): + assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths + exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] + # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; + else: + assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths + exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"] - for folder in all_state_dicts: - trainer = Trainer( - model=model_and_optimizers.model, - driver="torch", - device=4, - optimizers=model_and_optimizers.optimizers, - train_dataloader=model_and_optimizers.train_dataloader, - evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, - input_mapping=model_and_optimizers.input_mapping, - output_mapping=model_and_optimizers.output_mapping, - metrics=model_and_optimizers.metrics, + assert len(all_saved_model_paths) == 1 + all_state_dicts = [exception_model_path] - n_epochs=2, - output_from_new_proc="all" - ) + for folder in all_state_dicts: + trainer = Trainer( + model=model_and_optimizers.model, + driver="torch", + device=4, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=2, + output_from_new_proc="all" + ) - trainer.load_model(folder, only_state_dict=only_state_dict) - trainer.run() - trainer.driver.barrier() + trainer.load_model(folder, only_state_dict=only_state_dict) + trainer.run() + trainer.driver.barrier() - finally: - rank_zero_rm(path) - # pass + finally: + rank_zero_rm(path) + # pass if dist.is_initialized(): dist.destroy_process_group() @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context(timeout=100) def test_trainer_checkpoint_callback_1( model_and_optimizers: TrainerParameters, diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py index 9ce5c99d..04efb95c 100644 --- a/tests/core/callbacks/test_load_best_model_callback_torch.py +++ b/tests/core/callbacks/test_load_best_model_callback_torch.py @@ -20,13 +20,14 @@ from fastNLP.core.drivers.torch_driver import TorchSingleDriver from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context +from fastNLP import logger @dataclass class ArgMaxDatasetConfig: num_labels: int = 10 feature_dimension: int = 10 - data_num: int = 100 + data_num: int = 20 seed: int = 0 batch_size: int = 4 @@ -71,18 +72,31 @@ def model_and_optimizers(request): return trainer_params +from fastNLP import Metric +class CountMetrc(Metric): + def __init__(self): + super().__init__() + self.register_element('count', 0, aggregate_method='sum') + + def update(self, pred): + self.count += len(pred) + + def get_metric(self) -> dict: + return {'cnt': self.count.item()} + + @pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context def test_load_best_model_callback( model_and_optimizers: TrainerParameters, driver, - device + device, ): for save_folder in ['save_models', None]: for only_state_dict in [True, False]: - callbacks = [LoadBestModelCallback(monitor='acc')] - + callbacks = [LoadBestModelCallback(monitor='acc', only_state_dict=only_state_dict, + save_folder=save_folder)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -92,16 +106,16 @@ def test_load_best_model_callback( evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, - metrics=model_and_optimizers.metrics, - n_epochs=3, + metrics={'acc': Accuracy()}, + n_epochs=2, callbacks=callbacks, output_from_new_proc="all" ) trainer.run(num_eval_sanity_batch=0) - driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) - evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device, + _driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) + evaluator = Evaluator(model_and_optimizers.model, driver=_driver, device=device, dataloaders={'dl1': model_and_optimizers.evaluate_dataloaders}, metrics={'acc': Accuracy(aggregate_when_get_metric=False)}, output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, @@ -113,5 +127,3 @@ def test_load_best_model_callback( shutil.rmtree(save_folder, ignore_errors=True) if dist.is_initialized(): dist.destroy_process_group() - - diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index e49d9f88..1ed755d1 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -31,7 +31,7 @@ if _NEED_IMPORT_TORCH: class ArgMaxDatasetConfig: num_labels: int = 10 feature_dimension: int = 10 - data_num: int = 100 + data_num: int = 20 seed: int = 0 batch_size: int = 4 @@ -92,7 +92,7 @@ def model_and_optimizers(request): @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context def test_model_more_evaluate_callback_1( model_and_optimizers: TrainerParameters, @@ -121,7 +121,7 @@ def test_model_more_evaluate_callback_1( folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, save_object='model') ] - n_epochs = 5 + n_epochs = 3 trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -175,7 +175,7 @@ def test_model_more_evaluate_callback_1( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context def test_trainer_checkpoint_callback_1( model_and_optimizers: TrainerParameters, @@ -204,7 +204,7 @@ def test_trainer_checkpoint_callback_1( folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, save_object='trainer') ] - n_epochs = 5 + n_epochs = 2 trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -241,7 +241,7 @@ def test_trainer_checkpoint_callback_1( input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, - n_epochs=7, + n_epochs=5, output_from_new_proc="all", evaluate_fn='train_step' ) diff --git a/tests/core/callbacks/test_progress_callback_torch.py b/tests/core/callbacks/test_progress_callback_torch.py index d2f2f59b..a3af18b0 100644 --- a/tests/core/callbacks/test_progress_callback_torch.py +++ b/tests/core/callbacks/test_progress_callback_torch.py @@ -21,7 +21,7 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset class ArgMaxDatasetConfig: num_labels: int = 10 feature_dimension: int = 10 - data_num: int = 100 + data_num: int = 20 seed: int = 0 batch_size: int = 4 @@ -87,7 +87,7 @@ def test_run( model_and_optimizers: TrainerParameters, device): if device != 'cpu' and not torch.cuda.is_available(): pytest.skip(f"No cuda for device:{device}") - n_epochs = 5 + n_epochs = 2 for progress_bar in ['rich', 'auto', None, 'raw', 'tqdm']: trainer = Trainer( model=model_and_optimizers.model, diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 463f144d..8734426f 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -70,12 +70,23 @@ def magic_argv_env_context(fn=None, timeout=300): def _handle_timeout(signum, frame): raise TimeoutError(f"\nYour test fn: {fn.__name__} has timed out.\n") + # 恢复 logger + handlers = [handler for handler in logger.handlers] + formatters = [handler.formatter for handler in handlers] + level = logger.level + signal.signal(signal.SIGALRM, _handle_timeout) signal.alarm(timeout) res = fn(*args, **kwargs) signal.alarm(0) sys.argv = deepcopy(command) os.environ = env + + for formatter, handler in zip(formatters, handlers): + handler.setFormatter(formatter) + logger.handlers = handlers + logger.setLevel(level) + return res return wrapper