diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 73b712c9..a22f81d8 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -263,7 +263,7 @@ class Trainer(TrainerEventTrigger): def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, - catch_KeyboardInterrupt=True): + catch_KeyboardInterrupt=None): """ 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint 去保存断点重训的文件; @@ -273,15 +273,17 @@ class Trainer(TrainerEventTrigger): :param resume_from: 从哪个路径下恢复 trainer 的状态 :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 - 行。 + 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) :return: """ - - if self.driver.is_distributed(): - if catch_KeyboardInterrupt: - logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " - "driver. And we are gonna to set it to False.") - catch_KeyboardInterrupt = False + if catch_KeyboardInterrupt is None: + catch_KeyboardInterrupt = not self.driver.is_distributed() + else: + if self.driver.is_distributed(): + if catch_KeyboardInterrupt: + logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " + "driver. And we are gonna to set it to False.") + catch_KeyboardInterrupt = False self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)