diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index a47ab998..fc5d9d5b 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -10,7 +10,8 @@ __all__ = [ 'ProgressCallback', 'RichCallback', "LRSchedCallback", - 'LoadBestModelCallback' + 'LoadBestModelCallback', + "EarlyStopCallback" ] @@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback +from .early_stop_callback import EarlyStopCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index b2d99b51..4b553a1f 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -1,11 +1,15 @@ -from typing import Union, Callable, Dict, Optional +from typing import Union, Callable, Dict, Optional, Any +from abc import ABC __all__ = [ 'Callback', ] from .callback_events import Events, EventsList, Filter +from .utils import _get_monitor_value from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.log import logger +from fastNLP.core.utils import apply_to_collection class Callback: @@ -150,4 +154,82 @@ class _CallbackWrapper(Callback): return self.fn.__name__ +class CanItemDataType(ABC): + """ + 检测可以进行传输的对象。 + + """ + + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is CanItemDataType: + item = getattr(subclass, 'item', None) + return callable(item) + return NotImplemented + + +class HasMonitorCallback(Callback): + def __init__(self, monitor, larger_better, must_have_monitor=False): + self.set_monitor(monitor, larger_better) + self.must_have_moinitor = must_have_monitor + + def set_monitor(self, monitor, larger_better): + self.monitor = str(monitor) if monitor is not None else None + self.larger_better = bool(larger_better) + if larger_better: + self.monitor_value = float('-inf') + else: + self.monitor_value = float('inf') + self._real_monitor = self.monitor + + def on_after_trainer_initialized(self, trainer, driver): + """ + 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 + 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 + + :param trainer: + :param driver: + :return: + """ + if self.monitor is None and trainer.monitor is not None: + self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) + if self.must_have_moinitor and self.monitor is None: + raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " + f"You can set it in the initialization or through Trainer.") + + def get_monitor_value(self, results:Dict)->float: + """ + 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 + :param results: + :return: + """ + if len(results)==0: + return 0 + # 保证所有的 tensor 都被转换为了 python 特定的类型 + results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) + use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, + real_monitor=self._real_monitor, + res=results) + if self._real_monitor != use_monitor: # 发生了替换需要打印 + logger.warning( + f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " + f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") + self._real_monitor = use_monitor + return monitor_value + + def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): + """ + 检测 monitor_value 是否是更好的 + + :param monitor_value: + :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 + :return: + """ + better = False + if (self.larger_better and monitor_value > self.monitor_value) or \ + (not self.larger_better and monitor_value < self.monitor_value): + better = True + if keep_if_better: + self.monitor_value = monitor_value + return better \ No newline at end of file diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index d3a3b52d..839a9522 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -5,12 +5,12 @@ __all__ = [ import os from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping from pathlib import Path -from abc import ABC import sys +from copy import deepcopy import fastNLP -from .callback import Callback, Filter +from .callback import Callback, HasMonitorCallback from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME @@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir from fastNLP.core.utils import apply_to_collection -class CanItemDataType(ABC): - """ - 检测可以进行传输的对象。 - - """ - - @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: - if cls is CanItemDataType: - item = getattr(subclass, 'item', None) - return callable(item) - return NotImplemented - - - -class CheckpointCallback(Callback): +class CheckpointCallback(HasMonitorCallback): def __init__( self, monitor, @@ -48,12 +33,8 @@ class CheckpointCallback(Callback): model_save_fn: Optional[Callable] = None, **kwargs, ): - if monitor is None and save_topk is not None: - raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") - - if monitor is not None and not isinstance(monitor, str): - raise ValueError("Parameter `monitor` should be of 'str' type.") - + super().__init__(monitor=monitor, larger_better=larger_better, + must_have_monitor=save_topk is not None) if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") @@ -91,13 +72,12 @@ class CheckpointCallback(Callback): "`BaseException` type.") else: save_on_exception = [] - self.monitor = monitor + self.save_folder = Path(save_folder) self.save_every_n_epochs = save_every_n_epochs self.save_every_n_batches = save_every_n_batches self.save_last = save_last self.save_topk = save_topk - self.larger_better = larger_better self.only_state_dict = only_state_dict self.model_save_fn = model_save_fn self.save_on_exception = save_on_exception @@ -107,20 +87,22 @@ class CheckpointCallback(Callback): self._topk_model = {} self._topn = 0 # 表示目前已经保存了几个最好的模型; - # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 - # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 - # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; - # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; - self._real_monitor = self.monitor - # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; synchronize_mkdir(self.timestamp_path) - def on_validate_end(self, trainer, validate_res): - self._save_topk(trainer, validate_res) + def on_after_trainer_initialized(self, trainer, driver): + if self.save_topk is not None: + super().on_after_trainer_initialized(trainer, driver) + if self.save_topk is not None and trainer.evaluator is None: + logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") + + def on_validate_end(self, trainer, results): + if len(results) == 0: + return + self._save_topk(trainer, results) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: @@ -143,7 +125,7 @@ class CheckpointCallback(Callback): def on_sanity_check_end(self, trainer, sanity_check_res): # 主要核对一下 monitor 是否存在。 - self._get_validate_metric(sanity_check_res) + self.get_monitor_value(results=sanity_check_res) def on_save_checkpoint(self, trainer) -> Dict: """ @@ -154,8 +136,7 @@ class CheckpointCallback(Callback): states = {} states['timestamp_path'] = str(self.timestamp_path.absolute()) - states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, - function=lambda x:x.item()) + states['_topk_model'] = deepcopy(self._topk_model) states['save_topk'] = 0 if self.save_topk is None else self.save_topk states['_real_monitor'] = self._real_monitor return states @@ -176,30 +157,30 @@ class CheckpointCallback(Callback): self._topk_model.update(self._topk_model) self._real_monitor = states["real_monitor"] - def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): + def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): """ 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 :param trainer: - :param validate_res: + :param results: :return: """ if self.save_topk is not None: - _metric_value = self._get_validate_metric(validate_res) + monitor_value = self.get_monitor_value(results=results) folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ - f"-{self._real_monitor}_{_metric_value}" + f"-{self._real_monitor}_{monitor_value}" _should_save = False if self._topn < self.save_topk: - self._topk_model[folder_name] = _metric_value + self._topk_model[folder_name] = monitor_value self._topn += 1 _should_save = True else: _least_valuable_model = (min if self.larger_better else max)(self._topk_model, key=lambda x: self._topk_model[x]) - if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ - (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): - self._topk_model[folder_name] = _metric_value + if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ + (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): + self._topk_model[folder_name] = monitor_value _should_save = True self._topk_model.pop(_least_valuable_model) synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) @@ -235,7 +216,11 @@ class CheckpointCallback(Callback): :return: """ use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) + if self._real_monitor != use_monitor: + logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " + f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") self._real_monitor = use_monitor + return value @property @@ -263,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -310,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py new file mode 100644 index 00000000..602236f7 --- /dev/null +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -0,0 +1,61 @@ +__all__ = [ + 'EarlyStopCallback' +] + +from typing import Dict + +from .callback import HasMonitorCallback +from fastNLP.core.utils.exceptions import EarlyStopException + + +class EarlyStopCallback(HasMonitorCallback): + def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): + """ + + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 + :param larger_better: monitor 的值是否是越大越好。 + :param patience: 多少次 validate 不没有提升就停止。 + """ + super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) + self.wait = 0 + self.patience = patience + + def on_validate_end(self, trainer, results): + if len(results)==0: + return + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): + self.wait = 0 + else: + self.wait += 1 + + def on_fetch_data_begin(self, trainer): + # 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 + if self.wait >= self.patience: + raise EarlyStopException(f"After {self.wait} validations, no improvement for " + f"metric `{self._real_monitor}`") + + def on_train_epoch_begin(self, trainer): + # 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 + if self.wait >= self.patience: + raise EarlyStopException(f"After {self.wait} validations, no improvement for " + f"metric `{self._real_monitor}`(best value: {self.monitor_value})") + + def on_save_checkpoint(self, trainer) -> Dict: + states = { + 'patience': self.patience, + 'wait': self.wait, + 'monitor': self.monitor, + 'monitor_value': self.monitor_value + } + return states + + def on_load_checkpoint(self, trainer, states): + self.patience = states['patience'] + self.wait = states['wait'] + self.monitor = states['monitor'] + self.monitor_value = float(states['monitor_value']) + + def callback_name(self): + return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' + diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index e7b94f8c..9a4bb65f 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -4,8 +4,7 @@ __all__ = [ import os from typing import Optional, Callable -from .callback import Callback -from .utils import _get_monitor_value +from .callback import HasMonitorCallback from io import BytesIO import shutil @@ -14,15 +13,15 @@ from fastNLP.core.log import logger from fastNLP.envs import all_rank_call -class LoadBestModelCallback(Callback): - def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, +class LoadBestModelCallback(HasMonitorCallback): + def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, model_load_fn:Optional[Callable] = None, delete_after_train:bool = True): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。 + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 @@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback): 请在函数内完成对模型的加载。 :param delete_after_train: 在训练结束后是否删掉模型。 """ + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) if model_load_fn is not None: assert callable(model_load_fn), "`model_load_fn` must be a callable object." assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." @@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback): self.real_save_folder = None self.buffer = BytesIO() - self.monitor = monitor - self.larger_better = larger_better self.save_folder = save_folder self.only_state_dict = only_state_dict self.model_save_fn = model_save_fn self.model_load_fn = model_load_fn self.delete_after_after = delete_after_train - self._real_monitor = None - self.monitor_value = float('-inf') if larger_better else float('inf') def on_after_trainer_initialized(self, trainer, driver): if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: @@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback): raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " f"save best model when launch using script.") + super().on_after_trainer_initialized(trainer, driver) + + def on_sanity_check_end(self, trainer, sanity_check_res): + self.get_monitor_value(sanity_check_res) + def on_validate_end(self, trainer, results): - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (monitor_value < self.monitor_value and self.larger_better is False) or \ - (monitor_value > self.monitor_value and self.larger_better): - self.monitor_value = monitor_value + if len(results)==0: + return + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, model_save_fn=self.model_save_fn) diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 633fbb09..756d236b 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -8,7 +8,7 @@ __all__ = [ 'RichCallback' ] -from .callback import Callback +from .callback import HasMonitorCallback from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger @@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str): return None -class ProgressCallback(Callback): +class ProgressCallback(HasMonitorCallback): def on_train_end(self, trainer): f_rich_progress.stop() def on_sanity_check_end(self, trainer, sanity_check_res): if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=sanity_check_res) + self.get_monitor_value(sanity_check_res) class RichCallback(ProgressCallback): @@ -46,28 +44,22 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ - super().__init__() + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.progress_bar = f_rich_progress self.task2id = {} self.loss = 0 self.loss_round_ndigit = loss_round_ndigit - self.monitor = monitor - self.larger_better = larger_better - if larger_better: - self.monitor_value = float('-inf') - else: - self.monitor_value = float('inf') - self._real_monitor = monitor self.format_json = format_json def on_after_trainer_initialized(self, trainer, driver): if not self.progress_bar.disable: self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) + super(RichCallback, self).on_after_trainer_initialized(trainer, driver) def on_train_begin(self, trainer): self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, @@ -109,16 +101,12 @@ class RichCallback(ProgressCallback): text_style = '' characters = '-' if self.monitor is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if abs(self.monitor_value) != float('inf'): rule_style = 'spring_green3' text_style = '[bold]' characters = '+' - self.monitor_value = monitor_value self.progress_bar.print() self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " f"Batch:{trainer.batch_idx_in_epoch}", @@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback): :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ - super().__init__() + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.task2id = {} self.loss = 0 self.loss_round_ndigit = loss_round_ndigit - self.monitor = monitor - self.larger_better = larger_better - if larger_better: - self.monitor_value = float('-inf') - else: - self.monitor_value = float('inf') - self._real_monitor = monitor + self.set_monitor(monitor, larger_better) self.format_json = format_json self.num_signs = 10 @@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback): base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' text = '' if self.monitor is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if abs(self.monitor_value) != float('inf'): text = '+'*self.num_signs + base_text + '+'*self.num_signs - self.monitor_value = monitor_value if len(text) == 0: text = '-'*self.num_signs + base_text + '-'*self.num_signs diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py index 900aebf6..2720ba3f 100644 --- a/fastNLP/core/callbacks/utils.py +++ b/fastNLP/core/callbacks/utils.py @@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( if monitor in res: return monitor, res[monitor] + if real_monitor in res: + return real_monitor, res[real_monitor] + pairs = [] for idx, (key, value) in enumerate(res.items()): - match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) - pairs.append((key, value, match.size, idx)) + match_size = _match_length(monitor, key) + pairs.append((key, value, match_size, idx)) pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) key, value, match_size = pairs[0][:3] - if real_monitor is not None and real_monitor in res and real_monitor != key: - # 如果 real_monitor 比新找的更长就继续用之前的。 - match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) - if match.size > match_size: - return real_monitor, res[real_monitor] + return key, value + - logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " - f"we use the `{key}` as the monitor.") - real_monitor = key - return real_monitor, value +def _match_length(a:str, b:str)->int: + """ + 需要把长度短的放在前面 + + :param a: + :param b: + :return: + """ + short = a if len(a) < len(b) else b + long = a if len(a)>=len(b) else b + match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) + return match.size diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index bd66d0a0..865acc89 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -219,6 +219,7 @@ class Evaluator: def remove_progress_bar(self, dataloader_name): if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): f_rich_progress.destroy_task(self._rich_task_id) + f_rich_progress.refresh() # 使得最终的bar可以消失 delattr(self, '_rich_task_id') elif self.progress_bar == 'raw': desc = 'Evaluation ends' @@ -229,6 +230,7 @@ class Evaluator: def finally_progress_bar(self): if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): f_rich_progress.destroy_task(self._rich_task_id) + f_rich_progress.refresh() delattr(self, '_rich_task_id') @property diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index d710f967..b360c6a0 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -25,6 +25,7 @@ from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, matc from fastNLP.envs import rank_zero_call from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME +from fastNLP.core.utils.exceptions import EarlyStopException class Trainer(TrainerEventTrigger): @@ -49,6 +50,8 @@ class Trainer(TrainerEventTrigger): output_mapping: Optional[Union[Callable, Dict]] = None, accumulation_steps: int = 1, fp16: bool = False, + monitor: str = None, + larger_better: bool = True, marker: Optional[str] = None, **kwargs ): @@ -102,6 +105,10 @@ class Trainer(TrainerEventTrigger): 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; + :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 + 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param larger_better: monitor 的值是否是越大越好。 :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; :param kwargs: 一些其它的可能需要的参数; torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; @@ -210,6 +217,8 @@ class Trainer(TrainerEventTrigger): self.evaluator = None self.epoch_validate = lambda *args, **kwargs: ... self.step_validate = lambda *args, **kwargs: ... + self.monitor = monitor + self.larger_better = larger_better if metrics is not None and validate_dataloaders is not None: if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") @@ -239,6 +248,7 @@ class Trainer(TrainerEventTrigger): else: # validate_every > 0 self._step_validate_filter = Filter(every=validate_every) + self.metrics = metrics self.validate_every = validate_every @@ -320,6 +330,10 @@ class Trainer(TrainerEventTrigger): self.driver.barrier() self.on_train_end() self.driver.barrier() + + except EarlyStopException as e: + logger.info(f"Catch early stop exception: {e.msg}.") + self.on_exception(e) except KeyboardInterrupt as e: self.driver.on_exception() self.on_exception(e) diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index e19aa648..44cabcf4 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -599,7 +599,7 @@ class TorchDDPDriver(TorchDriver): :param group: :return: """ - return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) + return fastnlp_torch_all_gather(obj, group=group) def find_free_network_port() -> str: diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index 37717f54..5e3819e7 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -1,11 +1,8 @@ import io import pickle -from typing import Mapping _pickler = pickle.Pickler _unpickler = pickle.Unpickler -from abc import ABC -from typing import Any, Union, List -import numpy as np +from typing import Any, List from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 @@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist + try: + from torch._C._distributed_c10d import ProcessGroupMPI + except ImportError: + _MPI_AVAILABLE = False + + try: + from torch._C._distributed_c10d import ProcessGroupNCCL + except ImportError: + _NCCL_AVAILABLE = False + + try: + from torch._C._distributed_c10d import ProcessGroupGloo + from torch._C._distributed_c10d import _ProcessGroupWrapper + except ImportError: + _GLOO_AVAILABLE = False from fastNLP.core.utils import apply_to_collection - -def all_gather_object(object_list, obj, group=None): - """ - Gathers picklable objects from the whole group into a list. Similar to - :func:`all_gather`, but Python objects can be passed in. Note that the object - must be picklable in order to be gathered. - - Args: - object_list (list[Any]): Output list. It should be correctly sized as the - size of the group for this collective and will contain the output. - object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - - Returns: - None. If the calling rank is part of this group, the output of the - collective will be populated into the input ``object_list``. If the - calling rank is not part of the group, the passed in ``object_list`` will - be unmodified. - - .. note:: Note that this API differs slightly from the :func:`all_gather` - collective since it does not provide an ``async_op`` handle and thus - will be a blocking call. - - .. note:: For NCCL-based processed groups, internal tensor representations - of objects must be moved to the GPU device before communication takes - place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to - ensure that this is set so that each rank has an individual GPU, via - ``torch.cuda.set_device()``. - - .. warning:: - :func:`all_gather_object` uses ``pickle`` module implicitly, which is - known to be insecure. It is possible to construct malicious pickle data - which will execute arbitrary code during unpickling. Only call this - function with data you trust. - - Example:: - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist - >>> # Assumes world_size of 3. - >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object - >>> output = [None for _ in gather_objects] - >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) - >>> output - ['foo', 12, {1: 2}] - """ - if dist.distributed_c10d._rank_not_in_group(group): - return - - input_tensor, local_size = _object_to_tensor(obj) - current_device = torch.device("cpu") - if dist.is_nccl_available() and isinstance( - group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL - ): - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. - current_device = torch.device("cuda", torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - # Gather all local sizes. This is so that we can find the max size, and index - # until the correct size when deserializing the tensors. - group_size = dist.get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device - ) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - # Allgather tensor sizes - dist.all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] - # Resize tensor to max size across all ranks. - input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=current_device - ) - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] - for i in range(group_size) - ] - dist.all_gather(output_tensors, input_tensor, group=group) - # Deserialize outputs back to object. - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - if tensor.device != torch.device("cpu"): - tensor = tensor.cpu() - tensor_size = object_size_list[i] - object_list[i] = _tensor_to_object(tensor, tensor_size) - - def _validate_output_list_for_rank(my_rank, dst, gather_list): if dst == my_rank: if not gather_list: @@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): ) -def gather_object(obj, object_gather_list=None, dst=0, group=None): +def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): """ + 从其它 rank gather 东西到 dst rank 。 + Gathers picklable objects from the whole group in a single process. Similar to :func:`gather`, but Python objects can be passed in. Note that the object must be picklable in order to be gathered. @@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None): # Ensure object_gather_list is specified appopriately. my_rank = dist.get_rank() _validate_output_list_for_rank(my_rank, dst, object_gather_list) + # 防止 unpickle 的时候出现在了发送的 gpu 上。 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) input_tensor, local_size = _object_to_tensor(obj) group_backend = dist.get_backend(group) current_device = torch.device("cpu") @@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): return _tensor_to_object(tensor.cpu(), size) -def _all_gather(obj, **kwargs): - group = kwargs.get('group', None) - if isinstance(obj, torch.Tensor): - gathered_tensor = [torch.zeros_like(obj) for _ in - range(torch.distributed.get_world_size(group=group))] - - torch.distributed.all_gather(gathered_tensor, obj, group=group) - - return gathered_tensor - - elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): - tensor, size = obj - # 首先需要同步 size 吧? - group_size = dist.get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=tensor.device - ) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - dist.all_gather(object_size_list, size, group=group) - max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] - # Resize tensor to max size across all ranks. - tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=tensor.device - ) - - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] - for i in range(group_size) - ] - dist.all_gather(output_tensors, tensor, group=group) - object_list = [] - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - tensor_size = object_size_list[i] - object_list.append(_tensor_to_object(tensor, tensor_size)) - return object_list - elif isinstance(obj, tuple) and len(obj) == 2: - obj, _type = obj - gathered_tensor = [torch.zeros_like(obj) for _ in - range(torch.distributed.get_world_size(group=group))] - - torch.distributed.all_gather(gathered_tensor, obj, group=group) - - if _type == np.ndarray: - gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] - else: - gathered_tensor = [_type(t.item()) for t in gathered_tensor] - - return gathered_tensor - else: - raise RuntimeError("Unsupported types to implement all_gather.") - - -class CanTransferDataType(ABC): - """ - 检测可以进行传输的对象。 - - """ - - @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: - if cls is CanTransferDataType: - if issubclass(subclass, Mapping): - return False - if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray): - return True - return False - return NotImplemented - - -def _tensorize(obj, device=None): - if isinstance(obj, torch.Tensor): - return obj - if isinstance(obj, bool): - return torch.tensor(obj, dtype=torch.uint8, device=device), bool - if isinstance(obj, float): - return torch.tensor(obj, dtype=torch.float, device=device), float - if isinstance(obj, int): - return torch.tensor(obj, dtype=torch.int, device=device), int - if isinstance(obj, np.ndarray): - return torch.from_numpy(obj), np.ndarray - return _object_to_tensor(obj, device) - - def _to_device(tensor, device): return tensor.contiguous().to(device) -def convert_to_tensors(data: Any, device=None) -> Any: - data = apply_to_collection(data, CanTransferDataType, _tensorize) - def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]): - if isinstance(t, tuple): - if isinstance(t[1], torch.Tensor): # 说明是 object 转的 - return t[0].to(device).contiguous(), t[1].to(device) - else: # 说明第二个元素是type,见 to_dtype_tensor 函数 - return t[0].to(device).contiguous(), t[1] - return t.to(device).contiguous() - - data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device) - return data - - -def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: +def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: """ 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 @@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: {'a': 1, 'b':[1, 2], 'c':{'d': 2}} ] - :param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 - :param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 + :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 + 序列化之后进行传输。 + :param device: 当前该参数无意义。 :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 - # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) - if device is None: - device = torch.cuda.current_device() - if _TORCH_GREATER_EQUAL_1_8: + if isinstance(obj, torch.Tensor): + objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] + dist.all_gather(objs, obj, group=group) + else: objs = [None for _ in range(dist.get_world_size(group))] - dist.all_gather_object(objs, obj) - objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 - return objs - group = group if group is not None else torch.distributed.group.WORLD - data = convert_to_tensors(obj, device=device) - data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) - - objs = [] - - def _get_obj_on_idx(obj, idx): - return obj[idx] - - for i in range(dist.get_world_size(group)): - objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) - + # 防止 unpickle 的时候弄到发送的 gpu 上了 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) + if _TORCH_GREATER_EQUAL_1_8: + dist.all_gather_object(objs, obj, group=group) + else: + objs = all_gather_object(objs, obj, group=group) return objs -def fastnlp_torch_broadcast_object(obj, src, device, group=None): +def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): """ 将 src 上的 obj 对象广播到其它 rank 上。 @@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): :return: """ cur_rank = dist.get_rank(group) - # if cur_rank == src: - # # 如果有 tensor 全部移动到 cpu 上,方便 pickle - # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) - + if cur_rank == src: + # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) if _TORCH_GREATER_EQUAL_1_8: if cur_rank!=src: get_obj = [None] @@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): else: dist.broadcast_object_list([obj], src=src, group=group) return obj + if device is None: + device = torch.cuda.current_device() if cur_rank == src: tensor, size = _object_to_tensor(obj, device=device) @@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): return _tensor_to_object(tensor, tensor_size=size.item()) +def _check_for_nccl_backend(group): + pg = group or dist.distributed_c10d._get_default_group() + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, _ProcessGroupWrapper): + pg = pg.wrapped_pg + + return ( + dist.is_nccl_available() and + isinstance(pg, dist.ProcessGroupNCCL) + ) + + +def all_gather_object(object_list, obj, group=None): + """ + 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 + + Gathers picklable objects from the whole group into a list. Similar to + :func:`all_gather`, but Python objects can be passed in. Note that the object + must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as the + size of the group for this collective and will contain the output. + object (Any): Pickable Python object to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` will + be unmodified. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + :func:`all_gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + """ + if dist._rank_not_in_group(group): + return + + input_tensor, local_size = _object_to_tensor(obj) + current_device = torch.device("cpu") + is_nccl_backend = _check_for_nccl_backend(group) + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device("cuda", torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = dist.get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + dist.all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + dist.all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + if tensor.device != torch.device("cpu"): + tensor = tensor.cpu() + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py index a2a62d66..6298eae2 100644 --- a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py @@ -29,14 +29,16 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): class ClassifyFPreRecMetric(Metric): - def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, - tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, - only_gross: bool = True, f_type='micro', beta=1) -> None: + def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, + only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', + aggregate_when_get_metric: bool = False) -> None: super(ClassifyFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) if f_type not in ('micro', 'macro'): raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) - + if tag_vocab: + if not isinstance(tag_vocab, Vocabulary): + raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) self.ignore_labels = ignore_labels self.f_type = f_type self.beta = beta @@ -45,9 +47,32 @@ class ClassifyFPreRecMetric(Metric): self.tag_vocab = tag_vocab - self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ - defaultdict(partial(self.register_element, aggregate_method='sum')),\ - defaultdict(partial(self.register_element, aggregate_method='sum')) + self._tp = {} + self._fp = {} + self._fn = {} + if tag_vocab: + for word, _ in tag_vocab: + word = word.lower() + if word != 'o': + word = word[2:] + if word in self._true_positives: + continue + self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', + backend=backend) + self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', + backend=backend) + self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', + backend=backend) + elif num_class > 0: + for word in range(num_class): + self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', + backend=backend) + self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', + backend=backend) + self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', + backend=backend) + else: + raise ValueError() def get_metric(self) -> dict: r""" @@ -68,9 +93,11 @@ class ClassifyFPreRecMetric(Metric): tag_name = self.tag_vocab.to_word(tag) else: tag_name = int(tag) - tp = self._tp[tag] - fn = self._fn[tag] - fp = self._fp[tag] + tp = self._tp[tag].get_scalar() + fn = self._fn[tag].get_scalar() + fp = self._fp[tag].get_scalar() + if tp == fn == fp == 0: + continue f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) f_sum += f pre_sum += pre @@ -90,20 +117,29 @@ class ClassifyFPreRecMetric(Metric): if self.f_type == 'micro': f, pre, rec = _compute_f_pre_rec(self.beta_square, - sum(self._tp.values()), - sum(self._fn.values()), - sum(self._fp.values())) + sum(val.get_scalar() for val in self._tp.values()), + sum(val.get_scalar() for val in self._fn.values()), + sum(val.get_scalar() for val in self._fp.values())) evaluate_result['f'] = f evaluate_result['pre'] = pre evaluate_result['rec'] = rec - for key, value in evaluate_result.items(): evaluate_result[key] = round(value, 6) return evaluate_result def update(self, pred, target, seq_len=None): + r""" + evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), + torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) + :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), + torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) + :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). + 如果mask也被传进来的话seq_len会被忽略. + """ pred = self.tensor2numpy(pred) target = self.tensor2numpy(target) if seq_len is not None: @@ -122,14 +158,14 @@ class ClassifyFPreRecMetric(Metric): f"pred have element numbers: {len(target.flatten())}") pass - elif len(pred.ndim) == len(target.ndim) + 1: + elif pred.ndim == target.ndim + 1: pred = pred.argmax(axis=-1) - if seq_len is None and len(target.ndim) > 1: + if seq_len is None and target.ndim > 1: warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") else: raise RuntimeError(f"when pred have " - f"size:{pred.ndim}, target should have size: {pred.ndim} or " - f"{pred.ndim[:-1]}, got {target.ndim}.") + f"size:{pred.shape}, target should have size: {pred.shape} or " + f"{pred.shape[:-1]}, got {target.shape}.") if masks is not None: target = target * masks pred = pred * masks @@ -138,5 +174,3 @@ class ClassifyFPreRecMetric(Metric): self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() - - diff --git a/fastNLP/core/utils/exceptions.py b/fastNLP/core/utils/exceptions.py new file mode 100644 index 00000000..afedbcba --- /dev/null +++ b/fastNLP/core/utils/exceptions.py @@ -0,0 +1,10 @@ + +class EarlyStopException(BaseException): + r""" + 用于EarlyStop时从Trainer训练循环中跳出。 + + """ + + def __init__(self, msg): + super(EarlyStopException, self).__init__(msg) + self.msg = msg diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 256cc906..a865f4c1 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -94,9 +94,6 @@ class FRichProgress(Progress, metaclass=Singleton): self.print = self.console.print self.log = self.console.log - # start new - self.start() - self.console.show_cursor(show=True) return self def set_transient(self, transient: bool = True): @@ -154,6 +151,7 @@ class FRichProgress(Progress, metaclass=Singleton): super().start() self.console.show_cursor(show=True) + if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( "[progress.description]{task.description}", diff --git a/tests/core/callbacks/test_utils.py b/tests/core/callbacks/test_utils.py index 10aba0e0..fdec93e0 100644 --- a/tests/core/callbacks/test_utils.py +++ b/tests/core/callbacks/test_utils.py @@ -12,32 +12,27 @@ def test_get_monitor_value(): with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor == 'f1' and value==0.2 - assert 'We can not find' not in output[0] # 测试可以匹配,且选择更靠前的 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor=='acc#f1' and value==0.2 - assert 'We can not find' in output[0] # 测试monitor匹配不上,使用real_monitor res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: - monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) + monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) assert monitor=='acc#rec' and value==0.3 - assert 'We can not find' not in output[0] # 测试monitor/real_monitor匹配不上, 重新选择 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) assert monitor=='acc#f1' and value==0.2 - assert 'We can not find' in output[0] # 测试partial的位置 res = {"acc#acc": 0.52, "loss#loss": 2} with Capturing() as output: monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) assert monitor=='loss#loss' and value==2 - assert 'We can not find' in output[0] diff --git a/tests/core/drivers/torch_driver/test_dist_utils.py b/tests/core/drivers/torch_driver/test_dist_utils.py index 8fb7eb34..2d2145c8 100644 --- a/tests/core/drivers/torch_driver/test_dist_utils.py +++ b/tests/core/drivers/torch_driver/test_dist_utils.py @@ -7,38 +7,10 @@ import numpy as np # print(isinstance((1,), tuple)) # exit() -from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object +from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context - -def test_convert_to_tensors(): - local_rank = 0 - obj = { - 'tensor': torch.full(size=(2,), fill_value=local_rank), - 'numpy': np.full(shape=(1,), fill_value=local_rank), - 'bool': local_rank % 2 == 0, - 'float': local_rank + 0.1, - 'int': local_rank, - 'dict': { - 'rank': local_rank - }, - 'list': [local_rank] * 2, - 'str': 'xxx' - } - data = convert_to_tensors(obj) - assert len(data) == len(obj) - assert (data['tensor'] == obj['tensor']).sum() == 2 - for name in ['list', 'str']: - assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ - isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 - - for name in ['numpy', 'bool', 'float', 'int']: - assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 - - assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 - - @magic_argv_env_context def test_fastnlp_torch_all_gather(): os.environ['MASTER_ADDR'] = '127.0.0.1' @@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather(): 'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), torch.full(size=(2,), fill_value=local_rank).cuda()] } - data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) + data = fastnlp_torch_all_gather(obj) world_size = int(os.environ['WORLD_SIZE']) assert len(data) == world_size for i in range(world_size): @@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather(): assert data[i]['tensors'][0][0] == i for obj in [1, True, 'xxx']: - data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) + data = fastnlp_torch_all_gather(obj) assert len(data)==world_size assert data[0]==data[1] + dist.destroy_process_group() + @magic_argv_env_context def test_fastnlp_torch_broadcast_object(): os.environ['MASTER_ADDR'] = '127.0.0.1' @@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object(): for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) assert int(data)==0 + dist.destroy_process_group() diff --git a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py new file mode 100644 index 00000000..c9174e41 --- /dev/null +++ b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py @@ -0,0 +1,88 @@ +import pytest +import torch +import numpy as np + +from fastNLP.core.metrics import ClassifyFPreRecMetric + + +class TestClassfiyFPreRecMetric: + def test_case_1(self): + pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910], + [1.3410, 0.2889, -0.8667, -1.8580, 0.3029], + [0.7459, -1.1957, 0.3231, 0.0308, -0.1847], + [1.1439, -0.0057, 0.8203, 0.0312, -1.0051], + [-0.4870, 0.3215, -0.8290, 0.9221, 0.4683], + [0.9078, 1.0674, -0.5629, 0.3895, 0.8917], + [-0.7743, -0.4041, -0.9026, 0.2112, 1.0892], + [1.8232, -1.4188, -2.5615, -2.4187, 0.5907], + [-1.0592, 0.4164, -0.1192, 1.4238, -0.9258], + [-1.1137, 0.5773, 2.5778, 0.5398, -0.3323], + [-0.3868, -0.5165, 0.2286, -1.3876, 0.5561], + [-0.3304, 1.3619, -1.5744, 0.4902, -0.7661], + [1.8387, 0.5234, 0.4269, 1.3748, -1.2793], + [0.6692, 0.2571, 1.2425, -0.5894, -0.0184], + [0.4165, 0.4084, -0.1280, 1.4489, -2.3058], + [-0.5826, -0.5469, 1.5898, -0.2786, -0.9882], + [-1.5548, -2.2891, 0.2983, -1.2145, -0.1947], + [-0.7222, 2.3543, -0.5801, -0.0640, -1.5614], + [-1.4978, 1.9297, -1.3652, -0.2358, 2.5566], + [0.1561, -0.0316, 0.9331, 1.0363, 2.3949], + [0.2650, -0.8459, 1.3221, 0.1321, -1.1900], + [0.0664, -1.2353, -0.5242, -1.4491, 1.3300], + [-0.2744, 0.0941, 0.7157, 0.1404, 1.2046], + [0.9341, -0.6652, 1.4512, 0.9608, -0.3623], + [-1.1641, 0.0873, 0.1163, -0.2068, -0.7002], + [1.4775, -2.0025, -0.5634, -0.1589, 0.0247], + [1.0151, 1.0304, -0.1042, -0.6955, -0.0629], + [-0.3119, -0.4558, 0.7757, 0.0758, -1.6297], + [1.0654, 0.0313, -0.7716, 0.1194, 0.6913], + [-0.8088, -0.6648, -0.5018, -0.0230, -0.8207], + [-0.7753, -0.3508, 1.6163, 0.7158, 1.5207], + [0.8692, 0.7718, -0.6734, 0.6515, 0.0641]]) + arg_max_pred = torch.argmax(pred, dim=-1) + target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, + 0, 3, 0, 0, 0, 1, 3, 1]) + + metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) + metric.update(pred, target) + result_dict = metric.get_metric() + f1_score = 0.1882051282051282 + recall = 0.1619047619047619 + pre = 0.23928571428571427 + + ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} + for keys in ['f', 'pre', 'rec']: + np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) + + metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) + metric.update(pred, target) + result_dict = metric.get_metric() + f1_score = 0.21875 + recall = 0.21875 + pre = 0.21875 + + ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} + for keys in ['f', 'pre', 'rec']: + np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) + + metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) + metric.update(pred, target) + result_dict = metric.get_metric() + ground_truth = { + '0': {'f1-score': 0.13333333333333333, 'precision': 0.125, 'recall': 0.14285714285714285, 'support': 7}, + '1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5}, + '2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2}, + '3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9}, + '4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}, + 'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427, + 'recall': 0.1619047619047619, 'support': 32}, + 'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32}, + 'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875, + 'support': 32}} + for keys in result_dict.keys(): + if keys == "f" or "pre" or "rec": + continue + gl = str(keys[-1]) + tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"} + gk = tmp_d[keys[0]] + np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)