Browse Source

修改Trainer的catch_KeyboardInterrupt行为,防止一直warning

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
5b54a0cd73
1 changed files with 10 additions and 8 deletions
  1. +10
    -8
      fastNLP/core/controllers/trainer.py

+ 10
- 8
fastNLP/core/controllers/trainer.py View File

@@ -263,7 +263,7 @@ class Trainer(TrainerEventTrigger):

def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
catch_KeyboardInterrupt=True):
catch_KeyboardInterrupt=None):
"""
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint
去保存断点重训的文件;
@@ -273,15 +273,17 @@ class Trainer(TrainerEventTrigger):
:param resume_from: 从哪个路径下恢复 trainer 的状态
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。
:param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运
行。
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch )
:return:
"""

if self.driver.is_distributed():
if catch_KeyboardInterrupt:
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
"driver. And we are gonna to set it to False.")
catch_KeyboardInterrupt = False
if catch_KeyboardInterrupt is None:
catch_KeyboardInterrupt = not self.driver.is_distributed()
else:
if self.driver.is_distributed():
if catch_KeyboardInterrupt:
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
"driver. And we are gonna to set it to False.")
catch_KeyboardInterrupt = False

self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)



Loading…
Cancel
Save