|
|
@@ -263,7 +263,7 @@ class Trainer(TrainerEventTrigger): |
|
|
|
|
|
|
|
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, |
|
|
|
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, |
|
|
|
catch_KeyboardInterrupt=True): |
|
|
|
catch_KeyboardInterrupt=None): |
|
|
|
""" |
|
|
|
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint |
|
|
|
去保存断点重训的文件; |
|
|
@@ -273,15 +273,17 @@ class Trainer(TrainerEventTrigger): |
|
|
|
:param resume_from: 从哪个路径下恢复 trainer 的状态 |
|
|
|
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 |
|
|
|
:param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 |
|
|
|
行。 |
|
|
|
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) |
|
|
|
:return: |
|
|
|
""" |
|
|
|
|
|
|
|
if self.driver.is_distributed(): |
|
|
|
if catch_KeyboardInterrupt: |
|
|
|
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " |
|
|
|
"driver. And we are gonna to set it to False.") |
|
|
|
catch_KeyboardInterrupt = False |
|
|
|
if catch_KeyboardInterrupt is None: |
|
|
|
catch_KeyboardInterrupt = not self.driver.is_distributed() |
|
|
|
else: |
|
|
|
if self.driver.is_distributed(): |
|
|
|
if catch_KeyboardInterrupt: |
|
|
|
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " |
|
|
|
"driver. And we are gonna to set it to False.") |
|
|
|
catch_KeyboardInterrupt = False |
|
|
|
|
|
|
|
self._set_num_eval_batch_per_dl(num_eval_batch_per_dl) |
|
|
|
|
|
|
|