@@ -390,4 +390,23 @@ class HasMonitorCallback(Callback): | |||||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | if (self.larger_better and monitor_value1 > monitor_value2) or \ | ||||
(not self.larger_better and monitor_value1 < monitor_value2): | (not self.larger_better and monitor_value1 < monitor_value2): | ||||
better = True | better = True | ||||
return better | |||||
return better | |||||
@property | |||||
def monitor_name(self): | |||||
""" | |||||
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||||
:return: | |||||
""" | |||||
if callable(self.monitor): | |||||
try: | |||||
monitor_name = self.monitor.__qualname__ | |||||
except: | |||||
monitor_name = self.monitor.__name__ | |||||
elif self.monitor is None: | |||||
return None | |||||
else: | |||||
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 | |||||
monitor_name = str(self.monitor) | |||||
return monitor_name |
@@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
class CheckpointCallback(HasMonitorCallback): | class CheckpointCallback(HasMonitorCallback): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
monitor, | |||||
monitor:Optional[Union[str, Callable]]=None, | |||||
save_folder: Optional[Union[str, Path]] = None, | save_folder: Optional[Union[str, Path]] = None, | ||||
save_every_n_epochs: Optional[int] = None, | save_every_n_epochs: Optional[int] = None, | ||||
save_every_n_batches: Optional[int] = None, | save_every_n_batches: Optional[int] = None, | ||||
save_last: bool = True, | |||||
save_last: bool = False, | |||||
save_topk: Optional[int] = None, | save_topk: Optional[int] = None, | ||||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | ||||
larger_better: bool = True, | larger_better: bool = True, | ||||
@@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback): | |||||
model_save_fn: Optional[Callable] = None, | model_save_fn: Optional[Callable] = None, | ||||
**kwargs, | **kwargs, | ||||
): | ): | ||||
""" | |||||
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: | |||||
""" | |||||
super().__init__(monitor=monitor, larger_better=larger_better, | super().__init__(monitor=monitor, larger_better=larger_better, | ||||
must_have_monitor=save_topk is not None) | must_have_monitor=save_topk is not None) | ||||
if save_folder is None: | if save_folder is None: | ||||
logger.warning( | logger.warning( | ||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | "Parameter `path` is None, and we will use the current work directory to find and load your model.") | ||||
save_folder = Path.cwd() | save_folder = Path.cwd() | ||||
save_folder = Path(save_folder) | |||||
if not save_folder.exists(): | if not save_folder.exists(): | ||||
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | ||||
elif save_folder.is_file(): | elif save_folder.is_file(): | ||||
@@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback): | |||||
else: | else: | ||||
save_on_exception = [] | save_on_exception = [] | ||||
self.save_folder = Path(save_folder) | |||||
self.save_folder = save_folder | |||||
self.save_every_n_epochs = save_every_n_epochs | self.save_every_n_epochs = save_every_n_epochs | ||||
self.save_every_n_batches = save_every_n_batches | self.save_every_n_batches = save_every_n_batches | ||||
self.save_last = save_last | self.save_last = save_last | ||||
@@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback): | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | ||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | ||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | ||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||||
synchronize_mkdir(self.timestamp_path) | |||||
# 该 folder 只在保存真的要发生的时候再创建。 | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.save_topk is not None: | if self.save_topk is not None: | ||||
@@ -98,8 +117,6 @@ class CheckpointCallback(HasMonitorCallback): | |||||
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | ||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
if len(results) == 0: | |||||
return | |||||
self._save_topk(trainer, results) | self._save_topk(trainer, results) | ||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
@@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback): | |||||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | states['timestamp_path'] = str(self.timestamp_path.absolute()) | ||||
states['_topk_model'] = deepcopy(self._topk_model) | states['_topk_model'] = deepcopy(self._topk_model) | ||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | states['save_topk'] = 0 if self.save_topk is None else self.save_topk | ||||
states['_real_monitor'] = self._real_monitor | |||||
if isinstance(self._real_monitor, str): | |||||
states['_real_monitor'] = self._real_monitor | |||||
return states | return states | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
timestamp_path = states['timestamp_path'] | timestamp_path = states['timestamp_path'] | ||||
if not os.path.exists(timestamp_path): | if not os.path.exists(timestamp_path): | ||||
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " | |||||
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " | |||||
f" {self.timestamp_path.absolute()}.") | f" {self.timestamp_path.absolute()}.") | ||||
else: | else: | ||||
logger.info(f"Resume to save in path: {timestamp_path}.") | |||||
logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||||
self.timestamp_path = Path(timestamp_path) | self.timestamp_path = Path(timestamp_path) | ||||
_topk_model = states['_topk_model'] | _topk_model = states['_topk_model'] | ||||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | ||||
@@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback): | |||||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | ||||
f"as {save_topk}." | f"as {save_topk}." | ||||
self._topk_model.update(self._topk_model) | self._topk_model.update(self._topk_model) | ||||
self._real_monitor = states["real_monitor"] | |||||
self._real_monitor = states["_real_monitor"] | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | ||||
""" | """ | ||||
@@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | ||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||||
返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
""" | """ | ||||
@property | @property | ||||
def save_fn_name(self): | def save_fn_name(self): | ||||
""" | |||||
调用 Trainer 中的哪个函数。 | |||||
:return: | |||||
""" | |||||
return 'save_model' | return 'save_model' | ||||
@property | @property | ||||
@@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | ||||
:return: | :return: | ||||
""" | """ | ||||
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | @property | ||||
def folder_prefix(self): | def folder_prefix(self): | ||||
@@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | ||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||||
返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
""" | """ | ||||
@property | @property | ||||
def save_fn_name(self): | def save_fn_name(self): | ||||
""" | |||||
调用 Trainer 中的哪个函数。 | |||||
:return: | |||||
""" | |||||
return 'save' | return 'save' | ||||
@property | @property | ||||
@@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | ||||
:return: | :return: | ||||
""" | """ | ||||
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | @property | ||||
def folder_prefix(self): | def folder_prefix(self): | ||||
@@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | ||||
""" | """ | ||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param patience: 多少次 validate 不没有提升就停止。 | :param patience: 多少次 validate 不没有提升就停止。 | ||||
""" | """ | ||||
@@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
states = { | states = { | ||||
'patience': self.patience, | 'patience': self.patience, | ||||
'wait': self.wait, | 'wait': self.wait, | ||||
'monitor': self.monitor, | |||||
'monitor_value': self.monitor_value | 'monitor_value': self.monitor_value | ||||
} | } | ||||
if not callable(self._real_monitor): | |||||
states['_real_monitor'] = self._real_monitor | |||||
return states | return states | ||||
def on_load_checkpoint(self, trainer, states): | def on_load_checkpoint(self, trainer, states): | ||||
self.patience = states['patience'] | self.patience = states['patience'] | ||||
self.wait = states['wait'] | self.wait = states['wait'] | ||||
self.monitor = states['monitor'] | |||||
self.monitor_value = float(states['monitor_value']) | self.monitor_value = float(states['monitor_value']) | ||||
if '_real_monitor' in states: | |||||
self._real_monitor = states['_real_monitor'] | |||||
@property | |||||
def callback_name(self): | def callback_name(self): | ||||
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' | |||||
return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}' | |||||
@@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
""" | """ | ||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | ||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -44,10 +44,11 @@ class RichCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||||
也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | |||||
:param format_json: 是否format json再打印 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||||
:param format_json: 是否格式化 json 再打印 | |||||
""" | """ | ||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | ||||
self.print_every = print_every | self.print_every = print_every | ||||
@@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( | |||||
字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
@@ -36,7 +36,7 @@ class Evaluator: | |||||
model, | model, | ||||
dataloaders, | dataloaders, | ||||
metrics: Optional[Union[Dict, Metric]] = None, | metrics: Optional[Union[Dict, Metric]] = None, | ||||
driver: Union[str, Driver] = 'single', | |||||
driver: Union[str, Driver] = 'torch', | |||||
device: Optional[Union[int, List[int], str]] = None, | device: Optional[Union[int, List[int], str]] = None, | ||||
batch_step_fn: Optional[callable] = None, | batch_step_fn: Optional[callable] = None, | ||||
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | ||||
@@ -49,8 +49,8 @@ class Evaluator: | |||||
): | ): | ||||
""" | """ | ||||
:param dataloaders: | |||||
:param model: | :param model: | ||||
:param dataloaders: | |||||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | ||||
metric ,torchmetrics,allennlpmetrics等。 | metric ,torchmetrics,allennlpmetrics等。 | ||||
:param driver: 使用 driver 。 | :param driver: 使用 driver 。 | ||||
@@ -120,7 +120,8 @@ class Evaluator: | |||||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | if evaluate_fn is not None and not isinstance(evaluate_fn, str): | ||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | ||||
self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||||
self._evaluate_step, self._evaluate_step_signature_fn = \ | |||||
self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||||
self.evaluate_fn = evaluate_fn | self.evaluate_fn = evaluate_fn | ||||
self.dataloaders = {} | self.dataloaders = {} | ||||
@@ -134,8 +135,6 @@ class Evaluator: | |||||
self.driver.barrier() | self.driver.barrier() | ||||
self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) | |||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | ||||
""" | """ | ||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | ||||
@@ -20,7 +20,7 @@ class TrainBatchLoop(Loop): | |||||
else lambda *args, **kwargs: None | else lambda *args, **kwargs: None | ||||
dataloader = iter(dataloader) | dataloader = iter(dataloader) | ||||
indices = None | indices = None | ||||
while True: | |||||
while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: | |||||
try: | try: | ||||
trainer.on_fetch_data_begin() | trainer.on_fetch_data_begin() | ||||
batch = next(dataloader) | batch = next(dataloader) | ||||
@@ -30,10 +30,8 @@ class TrainBatchLoop(Loop): | |||||
batch = trainer.move_data_to_device(batch) | batch = trainer.move_data_to_device(batch) | ||||
except StopIteration: | except StopIteration: | ||||
break | break | ||||
except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||||
break | |||||
except BaseException as e: | except BaseException as e: | ||||
if indices: | |||||
if indices and not isinstance(e, EarlyStopException): | |||||
logger.debug(f"The following exception happens when running on samples: {indices}") | logger.debug(f"The following exception happens when running on samples: {indices}") | ||||
raise e | raise e | ||||
@@ -174,9 +174,8 @@ class Trainer(TrainerEventTrigger): | |||||
optimizers=optimizers, | optimizers=optimizers, | ||||
device=device, | device=device, | ||||
n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
validate_dataloaders=evaluate_dataloaders, | |||||
batch_step_fn=batch_step_fn, | batch_step_fn=batch_step_fn, | ||||
validate_batch_step_fn=evaluate_batch_step_fn, | |||||
z=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, | evaluate_fn=evaluate_fn, | ||||
callbacks=callbacks, | callbacks=callbacks, | ||||
metrics=metrics, | metrics=metrics, | ||||
@@ -264,7 +263,6 @@ class Trainer(TrainerEventTrigger): | |||||
self.on_after_trainer_initialized(self.driver) | self.on_after_trainer_initialized(self.driver) | ||||
self.driver.barrier() | self.driver.barrier() | ||||
self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||||
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | ||||
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | ||||
@@ -310,7 +308,7 @@ class Trainer(TrainerEventTrigger): | |||||
self.num_batches_per_epoch = len(self.dataloader) | self.num_batches_per_epoch = len(self.dataloader) | ||||
self.total_batches = self.num_batches_per_epoch * self.n_epochs | self.total_batches = self.num_batches_per_epoch * self.n_epochs | ||||
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | |||||
self.on_train_begin() | self.on_train_begin() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
self.driver.zero_grad(self.set_grad_to_none) | self.driver.zero_grad(self.set_grad_to_none) | ||||
@@ -637,6 +635,8 @@ class Trainer(TrainerEventTrigger): | |||||
:param folder: 保存断点重训 states 的文件地址; | :param folder: 保存断点重训 states 的文件地址; | ||||
:param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | :param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | ||||
只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | 只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | ||||
:param only_state_dict: 保存的 model 是否只包含了权重。 | |||||
:param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。 | |||||
""" | """ | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if isinstance(folder, str): | if isinstance(folder, str): | ||||
@@ -675,8 +675,6 @@ class Trainer(TrainerEventTrigger): | |||||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | ||||
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||||
self.batch_idx_in_epoch | |||||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | ||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | ||||
@@ -65,10 +65,10 @@ class TrainerState: | |||||
""" | """ | ||||
n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; | |||||
global_forward_batches: Optional[int] = None # 断点重训 | |||||
cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0; | |||||
global_forward_batches: Optional[int] = 0 # 断点重训 | |||||
batch_idx_in_epoch: Optional[int] = None # 断点重训 | |||||
batch_idx_in_epoch: Optional[int] = 0 # 断点重训 | |||||
num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | ||||
@@ -86,7 +86,7 @@ class Driver(ABC): | |||||
函数; | 函数; | ||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | ||||
:param fn: 由 Trainer 传入的用于网络前向传播一次的函数; | |||||
:param fn: 调用该函数进行一次计算。 | |||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | ||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | ||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | ||||
@@ -126,17 +126,6 @@ class Driver(ABC): | |||||
def model(self, model): | def model(self, model): | ||||
self._model = model | self._model = model | ||||
@staticmethod | |||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
r""" | |||||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 | |||||
行为是不相同的; | |||||
:param dataloader: 需要检测的输入的 `dataloader`; | |||||
:param dataloader_name: | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") | |||||
@property | @property | ||||
def optimizers(self) -> List: | def optimizers(self) -> List: | ||||
r""" | r""" | ||||
@@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver): | |||||
if hasattr(model, fn): | if hasattr(model, fn): | ||||
fn = getattr(model, fn) | fn = getattr(model, fn) | ||||
if not callable(fn): | if not callable(fn): | ||||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") | |||||
return fn, None | return fn, None | ||||
elif fn in {"train_step", "evaluate_step"}: | elif fn in {"train_step", "evaluate_step"}: | ||||
return model, model.forward | return model, model.forward | ||||
@@ -199,6 +199,7 @@ class TorchDriver(Driver): | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | num_consumed_batches = sampler_states['num_consumed_samples'] | ||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | ||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | ||||
states['sampler_states'] = sampler_states | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | ||||
@@ -181,6 +181,7 @@ class TestCheckNumberOfParameters: | |||||
def test_get_fun_msg(): | def test_get_fun_msg(): | ||||
# 测试运行 | |||||
def demo(x): | def demo(x): | ||||
pass | pass | ||||
@@ -1,3 +1,6 @@ | |||||
import numpy as np | |||||
class NormalIterator: | class NormalIterator: | ||||
def __init__(self, num_of_data=1000): | def __init__(self, num_of_data=1000): | ||||
self._num_of_data = num_of_data | self._num_of_data = num_of_data | ||||
@@ -15,4 +18,15 @@ class NormalIterator: | |||||
return self._data | return self._data | ||||
def __len__(self): | def __len__(self): | ||||
return self._num_of_data | |||||
return self._num_of_data | |||||
class RandomDataset: | |||||
def __init__(self, num_data=10): | |||||
self.data = np.random.rand(num_data) | |||||
def __len__(self): | |||||
return len(self.data) | |||||
def __getitem__(self, item): | |||||
return self.data[item] |