@@ -10,7 +10,8 @@ __all__ = [ | |||
'ProgressCallback', | |||
'RichCallback', | |||
"LRSchedCallback", | |||
'LoadBestModelCallback' | |||
'LoadBestModelCallback', | |||
"EarlyStopCallback" | |||
] | |||
@@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb | |||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||
from .lr_scheduler_callback import LRSchedCallback | |||
from .load_best_model_callback import LoadBestModelCallback | |||
from .early_stop_callback import EarlyStopCallback | |||
@@ -1,11 +1,15 @@ | |||
from typing import Union, Callable, Dict, Optional | |||
from typing import Union, Callable, Dict, Optional, Any | |||
from abc import ABC | |||
__all__ = [ | |||
'Callback', | |||
] | |||
from .callback_events import Events, EventsList, Filter | |||
from .utils import _get_monitor_value | |||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import apply_to_collection | |||
class Callback: | |||
@@ -150,4 +154,82 @@ class _CallbackWrapper(Callback): | |||
return self.fn.__name__ | |||
class CanItemDataType(ABC): | |||
""" | |||
检测可以进行传输的对象。 | |||
""" | |||
@classmethod | |||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
if cls is CanItemDataType: | |||
item = getattr(subclass, 'item', None) | |||
return callable(item) | |||
return NotImplemented | |||
class HasMonitorCallback(Callback): | |||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||
self.set_monitor(monitor, larger_better) | |||
self.must_have_moinitor = must_have_monitor | |||
def set_monitor(self, monitor, larger_better): | |||
self.monitor = str(monitor) if monitor is not None else None | |||
self.larger_better = bool(larger_better) | |||
if larger_better: | |||
self.monitor_value = float('-inf') | |||
else: | |||
self.monitor_value = float('inf') | |||
self._real_monitor = self.monitor | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
""" | |||
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||
:param trainer: | |||
:param driver: | |||
:return: | |||
""" | |||
if self.monitor is None and trainer.monitor is not None: | |||
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||
if self.must_have_moinitor and self.monitor is None: | |||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||
f"You can set it in the initialization or through Trainer.") | |||
def get_monitor_value(self, results:Dict)->float: | |||
""" | |||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | |||
:param results: | |||
:return: | |||
""" | |||
if len(results)==0: | |||
return 0 | |||
# 保证所有的 tensor 都被转换为了 python 特定的类型 | |||
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
if self._real_monitor != use_monitor: # 发生了替换需要打印 | |||
logger.warning( | |||
f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||
self._real_monitor = use_monitor | |||
return monitor_value | |||
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): | |||
""" | |||
检测 monitor_value 是否是更好的 | |||
:param monitor_value: | |||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||
:return: | |||
""" | |||
better = False | |||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||
(not self.larger_better and monitor_value < self.monitor_value): | |||
better = True | |||
if keep_if_better: | |||
self.monitor_value = monitor_value | |||
return better |
@@ -5,12 +5,12 @@ __all__ = [ | |||
import os | |||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||
from pathlib import Path | |||
from abc import ABC | |||
import sys | |||
from copy import deepcopy | |||
import fastNLP | |||
from .callback import Callback, Filter | |||
from .callback import Callback, HasMonitorCallback | |||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
@@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
from fastNLP.core.utils import apply_to_collection | |||
class CanItemDataType(ABC): | |||
""" | |||
检测可以进行传输的对象。 | |||
""" | |||
@classmethod | |||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
if cls is CanItemDataType: | |||
item = getattr(subclass, 'item', None) | |||
return callable(item) | |||
return NotImplemented | |||
class CheckpointCallback(Callback): | |||
class CheckpointCallback(HasMonitorCallback): | |||
def __init__( | |||
self, | |||
monitor, | |||
@@ -48,12 +33,8 @@ class CheckpointCallback(Callback): | |||
model_save_fn: Optional[Callable] = None, | |||
**kwargs, | |||
): | |||
if monitor is None and save_topk is not None: | |||
raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") | |||
if monitor is not None and not isinstance(monitor, str): | |||
raise ValueError("Parameter `monitor` should be of 'str' type.") | |||
super().__init__(monitor=monitor, larger_better=larger_better, | |||
must_have_monitor=save_topk is not None) | |||
if save_folder is None: | |||
logger.warning( | |||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||
@@ -91,13 +72,12 @@ class CheckpointCallback(Callback): | |||
"`BaseException` type.") | |||
else: | |||
save_on_exception = [] | |||
self.monitor = monitor | |||
self.save_folder = Path(save_folder) | |||
self.save_every_n_epochs = save_every_n_epochs | |||
self.save_every_n_batches = save_every_n_batches | |||
self.save_last = save_last | |||
self.save_topk = save_topk | |||
self.larger_better = larger_better | |||
self.only_state_dict = only_state_dict | |||
self.model_save_fn = model_save_fn | |||
self.save_on_exception = save_on_exception | |||
@@ -107,20 +87,22 @@ class CheckpointCallback(Callback): | |||
self._topk_model = {} | |||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | |||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 | |||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | |||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | |||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | |||
self._real_monitor = self.monitor | |||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||
synchronize_mkdir(self.timestamp_path) | |||
def on_validate_end(self, trainer, validate_res): | |||
self._save_topk(trainer, validate_res) | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if self.save_topk is not None: | |||
super().on_after_trainer_initialized(trainer, driver) | |||
if self.save_topk is not None and trainer.evaluator is None: | |||
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") | |||
def on_validate_end(self, trainer, results): | |||
if len(results) == 0: | |||
return | |||
self._save_topk(trainer, results) | |||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | |||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | |||
@@ -143,7 +125,7 @@ class CheckpointCallback(Callback): | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
# 主要核对一下 monitor 是否存在。 | |||
self._get_validate_metric(sanity_check_res) | |||
self.get_monitor_value(results=sanity_check_res) | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
""" | |||
@@ -154,8 +136,7 @@ class CheckpointCallback(Callback): | |||
states = {} | |||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, | |||
function=lambda x:x.item()) | |||
states['_topk_model'] = deepcopy(self._topk_model) | |||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||
states['_real_monitor'] = self._real_monitor | |||
return states | |||
@@ -176,30 +157,30 @@ class CheckpointCallback(Callback): | |||
self._topk_model.update(self._topk_model) | |||
self._real_monitor = states["real_monitor"] | |||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | |||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||
""" | |||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | |||
:param trainer: | |||
:param validate_res: | |||
:param results: | |||
:return: | |||
""" | |||
if self.save_topk is not None: | |||
_metric_value = self._get_validate_metric(validate_res) | |||
monitor_value = self.get_monitor_value(results=results) | |||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||
f"-{self._real_monitor}_{_metric_value}" | |||
f"-{self._real_monitor}_{monitor_value}" | |||
_should_save = False | |||
if self._topn < self.save_topk: | |||
self._topk_model[folder_name] = _metric_value | |||
self._topk_model[folder_name] = monitor_value | |||
self._topn += 1 | |||
_should_save = True | |||
else: | |||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||
key=lambda x: self._topk_model[x]) | |||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | |||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | |||
self._topk_model[folder_name] = _metric_value | |||
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ | |||
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): | |||
self._topk_model[folder_name] = monitor_value | |||
_should_save = True | |||
self._topk_model.pop(_least_valuable_model) | |||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | |||
@@ -235,7 +216,11 @@ class CheckpointCallback(Callback): | |||
:return: | |||
""" | |||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | |||
if self._real_monitor != use_monitor: | |||
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||
self._real_monitor = use_monitor | |||
return value | |||
@property | |||
@@ -263,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。 | |||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
@@ -310,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。 | |||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
@@ -0,0 +1,61 @@ | |||
__all__ = [ | |||
'EarlyStopCallback' | |||
] | |||
from typing import Dict | |||
from .callback import HasMonitorCallback | |||
from fastNLP.core.utils.exceptions import EarlyStopException | |||
class EarlyStopCallback(HasMonitorCallback): | |||
def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): | |||
""" | |||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||
:param larger_better: monitor 的值是否是越大越好。 | |||
:param patience: 多少次 validate 不没有提升就停止。 | |||
""" | |||
super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | |||
self.wait = 0 | |||
self.patience = patience | |||
def on_validate_end(self, trainer, results): | |||
if len(results)==0: | |||
return | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
self.wait = 0 | |||
else: | |||
self.wait += 1 | |||
def on_fetch_data_begin(self, trainer): | |||
# 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||
if self.wait >= self.patience: | |||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | |||
f"metric `{self._real_monitor}`") | |||
def on_train_epoch_begin(self, trainer): | |||
# 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||
if self.wait >= self.patience: | |||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | |||
f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
states = { | |||
'patience': self.patience, | |||
'wait': self.wait, | |||
'monitor': self.monitor, | |||
'monitor_value': self.monitor_value | |||
} | |||
return states | |||
def on_load_checkpoint(self, trainer, states): | |||
self.patience = states['patience'] | |||
self.wait = states['wait'] | |||
self.monitor = states['monitor'] | |||
self.monitor_value = float(states['monitor_value']) | |||
def callback_name(self): | |||
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' | |||
@@ -4,8 +4,7 @@ __all__ = [ | |||
import os | |||
from typing import Optional, Callable | |||
from .callback import Callback | |||
from .utils import _get_monitor_value | |||
from .callback import HasMonitorCallback | |||
from io import BytesIO | |||
import shutil | |||
@@ -14,15 +13,15 @@ from fastNLP.core.log import logger | |||
from fastNLP.envs import all_rank_call | |||
class LoadBestModelCallback(Callback): | |||
def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, | |||
class LoadBestModelCallback(HasMonitorCallback): | |||
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, | |||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | |||
model_load_fn:Optional[Callable] = None, | |||
delete_after_train:bool = True): | |||
""" | |||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | |||
:param str monitor: 监控的 metric 值。 | |||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||
:param larger_better: 该 metric 值是否是越大越好。 | |||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||
@@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback): | |||
请在函数内完成对模型的加载。 | |||
:param delete_after_train: 在训练结束后是否删掉模型。 | |||
""" | |||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | |||
if model_load_fn is not None: | |||
assert callable(model_load_fn), "`model_load_fn` must be a callable object." | |||
assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | |||
@@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback): | |||
self.real_save_folder = None | |||
self.buffer = BytesIO() | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
self.save_folder = save_folder | |||
self.only_state_dict = only_state_dict | |||
self.model_save_fn = model_save_fn | |||
self.model_load_fn = model_load_fn | |||
self.delete_after_after = delete_after_train | |||
self._real_monitor = None | |||
self.monitor_value = float('-inf') if larger_better else float('inf') | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | |||
@@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback): | |||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | |||
f"save best model when launch using script.") | |||
super().on_after_trainer_initialized(trainer, driver) | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
self.get_monitor_value(sanity_check_res) | |||
def on_validate_end(self, trainer, results): | |||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
if (monitor_value < self.monitor_value and self.larger_better is False) or \ | |||
(monitor_value > self.monitor_value and self.larger_better): | |||
self.monitor_value = monitor_value | |||
if len(results)==0: | |||
return | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if self.real_save_folder: | |||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
model_save_fn=self.model_save_fn) | |||
@@ -8,7 +8,7 @@ __all__ = [ | |||
'RichCallback' | |||
] | |||
from .callback import Callback | |||
from .callback import HasMonitorCallback | |||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||
from fastNLP.core.utils import f_rich_progress | |||
from fastNLP.core.log import logger | |||
@@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str): | |||
return None | |||
class ProgressCallback(Callback): | |||
class ProgressCallback(HasMonitorCallback): | |||
def on_train_end(self, trainer): | |||
f_rich_progress.stop() | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | |||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=sanity_check_res) | |||
self.get_monitor_value(sanity_check_res) | |||
class RichCallback(ProgressCallback): | |||
@@ -46,28 +44,22 @@ class RichCallback(ProgressCallback): | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||
:param larger_better: 是否是monitor的结果越大越好。 | |||
:param format_json: 是否format json再打印 | |||
""" | |||
super().__init__() | |||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||
self.print_every = print_every | |||
self.progress_bar = f_rich_progress | |||
self.task2id = {} | |||
self.loss = 0 | |||
self.loss_round_ndigit = loss_round_ndigit | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
if larger_better: | |||
self.monitor_value = float('-inf') | |||
else: | |||
self.monitor_value = float('inf') | |||
self._real_monitor = monitor | |||
self.format_json = format_json | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if not self.progress_bar.disable: | |||
self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) | |||
super(RichCallback, self).on_after_trainer_initialized(trainer, driver) | |||
def on_train_begin(self, trainer): | |||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | |||
@@ -109,16 +101,12 @@ class RichCallback(ProgressCallback): | |||
text_style = '' | |||
characters = '-' | |||
if self.monitor is not None: | |||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||
(not self.larger_better and monitor_value < self.monitor_value): | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if abs(self.monitor_value) != float('inf'): | |||
rule_style = 'spring_green3' | |||
text_style = '[bold]' | |||
characters = '+' | |||
self.monitor_value = monitor_value | |||
self.progress_bar.print() | |||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | |||
f"Batch:{trainer.batch_idx_in_epoch}", | |||
@@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback): | |||
:param larger_better: 是否是monitor的结果越大越好。 | |||
:param format_json: 是否format json再打印 | |||
""" | |||
super().__init__() | |||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||
self.print_every = print_every | |||
self.task2id = {} | |||
self.loss = 0 | |||
self.loss_round_ndigit = loss_round_ndigit | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
if larger_better: | |||
self.monitor_value = float('-inf') | |||
else: | |||
self.monitor_value = float('inf') | |||
self._real_monitor = monitor | |||
self.set_monitor(monitor, larger_better) | |||
self.format_json = format_json | |||
self.num_signs = 10 | |||
@@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback): | |||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||
text = '' | |||
if self.monitor is not None: | |||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||
(not self.larger_better and monitor_value < self.monitor_value): | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if abs(self.monitor_value) != float('inf'): | |||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||
self.monitor_value = monitor_value | |||
if len(text) == 0: | |||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||
@@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( | |||
if monitor in res: | |||
return monitor, res[monitor] | |||
if real_monitor in res: | |||
return real_monitor, res[real_monitor] | |||
pairs = [] | |||
for idx, (key, value) in enumerate(res.items()): | |||
match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) | |||
pairs.append((key, value, match.size, idx)) | |||
match_size = _match_length(monitor, key) | |||
pairs.append((key, value, match_size, idx)) | |||
pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) | |||
key, value, match_size = pairs[0][:3] | |||
if real_monitor is not None and real_monitor in res and real_monitor != key: | |||
# 如果 real_monitor 比新找的更长就继续用之前的。 | |||
match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) | |||
if match.size > match_size: | |||
return real_monitor, res[real_monitor] | |||
return key, value | |||
logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||
f"we use the `{key}` as the monitor.") | |||
real_monitor = key | |||
return real_monitor, value | |||
def _match_length(a:str, b:str)->int: | |||
""" | |||
需要把长度短的放在前面 | |||
:param a: | |||
:param b: | |||
:return: | |||
""" | |||
short = a if len(a) < len(b) else b | |||
long = a if len(a)>=len(b) else b | |||
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | |||
return match.size | |||
@@ -219,6 +219,7 @@ class Evaluator: | |||
def remove_progress_bar(self, dataloader_name): | |||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
f_rich_progress.refresh() # 使得最终的bar可以消失 | |||
delattr(self, '_rich_task_id') | |||
elif self.progress_bar == 'raw': | |||
desc = 'Evaluation ends' | |||
@@ -229,6 +230,7 @@ class Evaluator: | |||
def finally_progress_bar(self): | |||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
f_rich_progress.refresh() | |||
delattr(self, '_rich_task_id') | |||
@property | |||
@@ -25,6 +25,7 @@ from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, matc | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
from fastNLP.core.utils.exceptions import EarlyStopException | |||
class Trainer(TrainerEventTrigger): | |||
@@ -49,6 +50,8 @@ class Trainer(TrainerEventTrigger): | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
accumulation_steps: int = 1, | |||
fp16: bool = False, | |||
monitor: str = None, | |||
larger_better: bool = True, | |||
marker: Optional[str] = None, | |||
**kwargs | |||
): | |||
@@ -102,6 +105,10 @@ class Trainer(TrainerEventTrigger): | |||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | |||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | |||
:param fp16: 是否开启混合精度训练;默认为 False; | |||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。 | |||
:param larger_better: monitor 的值是否是越大越好。 | |||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | |||
:param kwargs: 一些其它的可能需要的参数; | |||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||
@@ -210,6 +217,8 @@ class Trainer(TrainerEventTrigger): | |||
self.evaluator = None | |||
self.epoch_validate = lambda *args, **kwargs: ... | |||
self.step_validate = lambda *args, **kwargs: ... | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
if metrics is not None and validate_dataloaders is not None: | |||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||
@@ -239,6 +248,7 @@ class Trainer(TrainerEventTrigger): | |||
else: | |||
# validate_every > 0 | |||
self._step_validate_filter = Filter(every=validate_every) | |||
self.metrics = metrics | |||
self.validate_every = validate_every | |||
@@ -320,6 +330,10 @@ class Trainer(TrainerEventTrigger): | |||
self.driver.barrier() | |||
self.on_train_end() | |||
self.driver.barrier() | |||
except EarlyStopException as e: | |||
logger.info(f"Catch early stop exception: {e.msg}.") | |||
self.on_exception(e) | |||
except KeyboardInterrupt as e: | |||
self.driver.on_exception() | |||
self.on_exception(e) | |||
@@ -599,7 +599,7 @@ class TorchDDPDriver(TorchDriver): | |||
:param group: | |||
:return: | |||
""" | |||
return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) | |||
return fastnlp_torch_all_gather(obj, group=group) | |||
def find_free_network_port() -> str: | |||
@@ -1,11 +1,8 @@ | |||
import io | |||
import pickle | |||
from typing import Mapping | |||
_pickler = pickle.Pickler | |||
_unpickler = pickle.Unpickler | |||
from abc import ABC | |||
from typing import Any, Union, List | |||
import numpy as np | |||
from typing import Any, List | |||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
@@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch import distributed as dist | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupMPI | |||
except ImportError: | |||
_MPI_AVAILABLE = False | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupNCCL | |||
except ImportError: | |||
_NCCL_AVAILABLE = False | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupGloo | |||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
except ImportError: | |||
_GLOO_AVAILABLE = False | |||
from fastNLP.core.utils import apply_to_collection | |||
def all_gather_object(object_list, obj, group=None): | |||
""" | |||
Gathers picklable objects from the whole group into a list. Similar to | |||
:func:`all_gather`, but Python objects can be passed in. Note that the object | |||
must be picklable in order to be gathered. | |||
Args: | |||
object_list (list[Any]): Output list. It should be correctly sized as the | |||
size of the group for this collective and will contain the output. | |||
object (Any): Pickable Python object to be broadcast from current process. | |||
group (ProcessGroup, optional): The process group to work on. If None, | |||
the default process group will be used. Default is ``None``. | |||
Returns: | |||
None. If the calling rank is part of this group, the output of the | |||
collective will be populated into the input ``object_list``. If the | |||
calling rank is not part of the group, the passed in ``object_list`` will | |||
be unmodified. | |||
.. note:: Note that this API differs slightly from the :func:`all_gather` | |||
collective since it does not provide an ``async_op`` handle and thus | |||
will be a blocking call. | |||
.. note:: For NCCL-based processed groups, internal tensor representations | |||
of objects must be moved to the GPU device before communication takes | |||
place. In this case, the device used is given by | |||
``torch.cuda.current_device()`` and it is the user's responsiblity to | |||
ensure that this is set so that each rank has an individual GPU, via | |||
``torch.cuda.set_device()``. | |||
.. warning:: | |||
:func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||
known to be insecure. It is possible to construct malicious pickle data | |||
which will execute arbitrary code during unpickling. Only call this | |||
function with data you trust. | |||
Example:: | |||
>>> # Note: Process group initialization omitted on each rank. | |||
>>> import torch.distributed as dist | |||
>>> # Assumes world_size of 3. | |||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||
>>> output = [None for _ in gather_objects] | |||
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
""" | |||
if dist.distributed_c10d._rank_not_in_group(group): | |||
return | |||
input_tensor, local_size = _object_to_tensor(obj) | |||
current_device = torch.device("cpu") | |||
if dist.is_nccl_available() and isinstance( | |||
group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL | |||
): | |||
# See note about using torch.cuda.current_device() here in docstring. | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
input_tensor = input_tensor.to(current_device) | |||
local_size = local_size.to(current_device) | |||
# Gather all local sizes. This is so that we can find the max size, and index | |||
# until the correct size when deserializing the tensors. | |||
group_size = dist.get_world_size(group=group) | |||
object_sizes_tensor = torch.zeros( | |||
group_size, dtype=torch.long, device=current_device | |||
) | |||
object_size_list = [ | |||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||
] | |||
# Allgather tensor sizes | |||
dist.all_gather(object_size_list, local_size, group=group) | |||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||
# Resize tensor to max size across all ranks. | |||
input_tensor.resize_(max_object_size) | |||
coalesced_output_tensor = torch.empty( | |||
max_object_size * group_size, dtype=torch.uint8, device=current_device | |||
) | |||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||
output_tensors = [ | |||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||
for i in range(group_size) | |||
] | |||
dist.all_gather(output_tensors, input_tensor, group=group) | |||
# Deserialize outputs back to object. | |||
for i, tensor in enumerate(output_tensors): | |||
tensor = tensor.type(torch.uint8) | |||
if tensor.device != torch.device("cpu"): | |||
tensor = tensor.cpu() | |||
tensor_size = object_size_list[i] | |||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||
def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
if dst == my_rank: | |||
if not gather_list: | |||
@@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
) | |||
def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
""" | |||
从其它 rank gather 东西到 dst rank 。 | |||
Gathers picklable objects from the whole group in a single process. | |||
Similar to :func:`gather`, but Python objects can be passed in. Note that the | |||
object must be picklable in order to be gathered. | |||
@@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
# Ensure object_gather_list is specified appopriately. | |||
my_rank = dist.get_rank() | |||
_validate_output_list_for_rank(my_rank, dst, object_gather_list) | |||
# 防止 unpickle 的时候出现在了发送的 gpu 上。 | |||
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
input_tensor, local_size = _object_to_tensor(obj) | |||
group_backend = dist.get_backend(group) | |||
current_device = torch.device("cpu") | |||
@@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | |||
return _tensor_to_object(tensor.cpu(), size) | |||
def _all_gather(obj, **kwargs): | |||
group = kwargs.get('group', None) | |||
if isinstance(obj, torch.Tensor): | |||
gathered_tensor = [torch.zeros_like(obj) for _ in | |||
range(torch.distributed.get_world_size(group=group))] | |||
torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||
return gathered_tensor | |||
elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): | |||
tensor, size = obj | |||
# 首先需要同步 size 吧? | |||
group_size = dist.get_world_size(group=group) | |||
object_sizes_tensor = torch.zeros( | |||
group_size, dtype=torch.long, device=tensor.device | |||
) | |||
object_size_list = [ | |||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||
] | |||
dist.all_gather(object_size_list, size, group=group) | |||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||
# Resize tensor to max size across all ranks. | |||
tensor.resize_(max_object_size) | |||
coalesced_output_tensor = torch.empty( | |||
max_object_size * group_size, dtype=torch.uint8, device=tensor.device | |||
) | |||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||
output_tensors = [ | |||
coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] | |||
for i in range(group_size) | |||
] | |||
dist.all_gather(output_tensors, tensor, group=group) | |||
object_list = [] | |||
for i, tensor in enumerate(output_tensors): | |||
tensor = tensor.type(torch.uint8) | |||
tensor_size = object_size_list[i] | |||
object_list.append(_tensor_to_object(tensor, tensor_size)) | |||
return object_list | |||
elif isinstance(obj, tuple) and len(obj) == 2: | |||
obj, _type = obj | |||
gathered_tensor = [torch.zeros_like(obj) for _ in | |||
range(torch.distributed.get_world_size(group=group))] | |||
torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||
if _type == np.ndarray: | |||
gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] | |||
else: | |||
gathered_tensor = [_type(t.item()) for t in gathered_tensor] | |||
return gathered_tensor | |||
else: | |||
raise RuntimeError("Unsupported types to implement all_gather.") | |||
class CanTransferDataType(ABC): | |||
""" | |||
检测可以进行传输的对象。 | |||
""" | |||
@classmethod | |||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
if cls is CanTransferDataType: | |||
if issubclass(subclass, Mapping): | |||
return False | |||
if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray): | |||
return True | |||
return False | |||
return NotImplemented | |||
def _tensorize(obj, device=None): | |||
if isinstance(obj, torch.Tensor): | |||
return obj | |||
if isinstance(obj, bool): | |||
return torch.tensor(obj, dtype=torch.uint8, device=device), bool | |||
if isinstance(obj, float): | |||
return torch.tensor(obj, dtype=torch.float, device=device), float | |||
if isinstance(obj, int): | |||
return torch.tensor(obj, dtype=torch.int, device=device), int | |||
if isinstance(obj, np.ndarray): | |||
return torch.from_numpy(obj), np.ndarray | |||
return _object_to_tensor(obj, device) | |||
def _to_device(tensor, device): | |||
return tensor.contiguous().to(device) | |||
def convert_to_tensors(data: Any, device=None) -> Any: | |||
data = apply_to_collection(data, CanTransferDataType, _tensorize) | |||
def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]): | |||
if isinstance(t, tuple): | |||
if isinstance(t[1], torch.Tensor): # 说明是 object 转的 | |||
return t[0].to(device).contiguous(), t[1].to(device) | |||
else: # 说明第二个元素是type,见 to_dtype_tensor 函数 | |||
return t[0].to(device).contiguous(), t[1] | |||
return t.to(device).contiguous() | |||
data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device) | |||
return data | |||
def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
""" | |||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||
@@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||
] | |||
:param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 | |||
:param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 | |||
:param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 | |||
序列化之后进行传输。 | |||
:param device: 当前该参数无意义。 | |||
:param group: | |||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||
""" | |||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
if device is None: | |||
device = torch.cuda.current_device() | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
if isinstance(obj, torch.Tensor): | |||
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||
dist.all_gather(objs, obj, group=group) | |||
else: | |||
objs = [None for _ in range(dist.get_world_size(group))] | |||
dist.all_gather_object(objs, obj) | |||
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
return objs | |||
group = group if group is not None else torch.distributed.group.WORLD | |||
data = convert_to_tensors(obj, device=device) | |||
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | |||
objs = [] | |||
def _get_obj_on_idx(obj, idx): | |||
return obj[idx] | |||
for i in range(dist.get_world_size(group)): | |||
objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) | |||
# 防止 unpickle 的时候弄到发送的 gpu 上了 | |||
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
dist.all_gather_object(objs, obj, group=group) | |||
else: | |||
objs = all_gather_object(objs, obj, group=group) | |||
return objs | |||
def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
""" | |||
将 src 上的 obj 对象广播到其它 rank 上。 | |||
@@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
:return: | |||
""" | |||
cur_rank = dist.get_rank(group) | |||
# if cur_rank == src: | |||
# # 如果有 tensor 全部移动到 cpu 上,方便 pickle | |||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
if cur_rank == src: | |||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
if cur_rank!=src: | |||
get_obj = [None] | |||
@@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
else: | |||
dist.broadcast_object_list([obj], src=src, group=group) | |||
return obj | |||
if device is None: | |||
device = torch.cuda.current_device() | |||
if cur_rank == src: | |||
tensor, size = _object_to_tensor(obj, device=device) | |||
@@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
return _tensor_to_object(tensor, tensor_size=size.item()) | |||
def _check_for_nccl_backend(group): | |||
pg = group or dist.distributed_c10d._get_default_group() | |||
# It is not expected for PG to be wrapped many times, but support it just | |||
# in case | |||
while isinstance(pg, _ProcessGroupWrapper): | |||
pg = pg.wrapped_pg | |||
return ( | |||
dist.is_nccl_available() and | |||
isinstance(pg, dist.ProcessGroupNCCL) | |||
) | |||
def all_gather_object(object_list, obj, group=None): | |||
""" | |||
复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 | |||
Gathers picklable objects from the whole group into a list. Similar to | |||
:func:`all_gather`, but Python objects can be passed in. Note that the object | |||
must be picklable in order to be gathered. | |||
Args: | |||
object_list (list[Any]): Output list. It should be correctly sized as the | |||
size of the group for this collective and will contain the output. | |||
object (Any): Pickable Python object to be broadcast from current process. | |||
group (ProcessGroup, optional): The process group to work on. If None, | |||
the default process group will be used. Default is ``None``. | |||
Returns: | |||
None. If the calling rank is part of this group, the output of the | |||
collective will be populated into the input ``object_list``. If the | |||
calling rank is not part of the group, the passed in ``object_list`` will | |||
be unmodified. | |||
.. note:: Note that this API differs slightly from the :func:`all_gather` | |||
collective since it does not provide an ``async_op`` handle and thus | |||
will be a blocking call. | |||
.. note:: For NCCL-based processed groups, internal tensor representations | |||
of objects must be moved to the GPU device before communication takes | |||
place. In this case, the device used is given by | |||
``torch.cuda.current_device()`` and it is the user's responsiblity to | |||
ensure that this is set so that each rank has an individual GPU, via | |||
``torch.cuda.set_device()``. | |||
.. warning:: | |||
:func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||
known to be insecure. It is possible to construct malicious pickle data | |||
which will execute arbitrary code during unpickling. Only call this | |||
function with data you trust. | |||
Example:: | |||
>>> # Note: Process group initialization omitted on each rank. | |||
>>> import torch.distributed as dist | |||
>>> # Assumes world_size of 3. | |||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||
>>> output = [None for _ in gather_objects] | |||
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
""" | |||
if dist._rank_not_in_group(group): | |||
return | |||
input_tensor, local_size = _object_to_tensor(obj) | |||
current_device = torch.device("cpu") | |||
is_nccl_backend = _check_for_nccl_backend(group) | |||
if is_nccl_backend: | |||
# See note about using torch.cuda.current_device() here in docstring. | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
input_tensor = input_tensor.to(current_device) | |||
local_size = local_size.to(current_device) | |||
# Gather all local sizes. This is so that we can find the max size, and index | |||
# until the correct size when deserializing the tensors. | |||
group_size = dist.get_world_size(group=group) | |||
object_sizes_tensor = torch.zeros( | |||
group_size, dtype=torch.long, device=current_device | |||
) | |||
object_size_list = [ | |||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||
] | |||
# Allgather tensor sizes | |||
dist.all_gather(object_size_list, local_size, group=group) | |||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||
# Resize tensor to max size across all ranks. | |||
input_tensor.resize_(max_object_size) | |||
coalesced_output_tensor = torch.empty( | |||
max_object_size * group_size, dtype=torch.uint8, device=current_device | |||
) | |||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||
output_tensors = [ | |||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||
for i in range(group_size) | |||
] | |||
dist.all_gather(output_tensors, input_tensor, group=group) | |||
# Deserialize outputs back to object. | |||
for i, tensor in enumerate(output_tensors): | |||
tensor = tensor.type(torch.uint8) | |||
if tensor.device != torch.device("cpu"): | |||
tensor = tensor.cpu() | |||
tensor_size = object_size_list[i] | |||
object_list[i] = _tensor_to_object(tensor, tensor_size) |
@@ -29,14 +29,16 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
class ClassifyFPreRecMetric(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, | |||
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, | |||
only_gross: bool = True, f_type='micro', beta=1) -> None: | |||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | |||
aggregate_when_get_metric: bool = False) -> None: | |||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | |||
aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
if tag_vocab: | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
@@ -45,9 +47,32 @@ class ClassifyFPreRecMetric(Metric): | |||
self.tag_vocab = tag_vocab | |||
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||
defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||
defaultdict(partial(self.register_element, aggregate_method='sum')) | |||
self._tp = {} | |||
self._fp = {} | |||
self._fn = {} | |||
if tag_vocab: | |||
for word, _ in tag_vocab: | |||
word = word.lower() | |||
if word != 'o': | |||
word = word[2:] | |||
if word in self._true_positives: | |||
continue | |||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
elif num_class > 0: | |||
for word in range(num_class): | |||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
else: | |||
raise ValueError() | |||
def get_metric(self) -> dict: | |||
r""" | |||
@@ -68,9 +93,11 @@ class ClassifyFPreRecMetric(Metric): | |||
tag_name = self.tag_vocab.to_word(tag) | |||
else: | |||
tag_name = int(tag) | |||
tp = self._tp[tag] | |||
fn = self._fn[tag] | |||
fp = self._fp[tag] | |||
tp = self._tp[tag].get_scalar() | |||
fn = self._fn[tag].get_scalar() | |||
fp = self._fp[tag].get_scalar() | |||
if tp == fn == fp == 0: | |||
continue | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
@@ -90,20 +117,29 @@ class ClassifyFPreRecMetric(Metric): | |||
if self.f_type == 'micro': | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
sum(self._tp.values()), | |||
sum(self._fn.values()), | |||
sum(self._fp.values())) | |||
sum(val.get_scalar() for val in self._tp.values()), | |||
sum(val.get_scalar() for val in self._fn.values()), | |||
sum(val.get_scalar() for val in self._fp.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
for key, value in evaluate_result.items(): | |||
evaluate_result[key] = round(value, 6) | |||
return evaluate_result | |||
def update(self, pred, target, seq_len=None): | |||
r""" | |||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||
如果mask也被传进来的话seq_len会被忽略. | |||
""" | |||
pred = self.tensor2numpy(pred) | |||
target = self.tensor2numpy(target) | |||
if seq_len is not None: | |||
@@ -122,14 +158,14 @@ class ClassifyFPreRecMetric(Metric): | |||
f"pred have element numbers: {len(target.flatten())}") | |||
pass | |||
elif len(pred.ndim) == len(target.ndim) + 1: | |||
elif pred.ndim == target.ndim + 1: | |||
pred = pred.argmax(axis=-1) | |||
if seq_len is None and len(target.ndim) > 1: | |||
if seq_len is None and target.ndim > 1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"when pred have " | |||
f"size:{pred.ndim}, target should have size: {pred.ndim} or " | |||
f"{pred.ndim[:-1]}, got {target.ndim}.") | |||
f"size:{pred.shape}, target should have size: {pred.shape} or " | |||
f"{pred.shape[:-1]}, got {target.shape}.") | |||
if masks is not None: | |||
target = target * masks | |||
pred = pred * masks | |||
@@ -138,5 +174,3 @@ class ClassifyFPreRecMetric(Metric): | |||
self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() | |||
self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() | |||
self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() | |||
@@ -0,0 +1,10 @@ | |||
class EarlyStopException(BaseException): | |||
r""" | |||
用于EarlyStop时从Trainer训练循环中跳出。 | |||
""" | |||
def __init__(self, msg): | |||
super(EarlyStopException, self).__init__(msg) | |||
self.msg = msg |
@@ -94,9 +94,6 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
self.print = self.console.print | |||
self.log = self.console.log | |||
# start new | |||
self.start() | |||
self.console.show_cursor(show=True) | |||
return self | |||
def set_transient(self, transient: bool = True): | |||
@@ -154,6 +151,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
super().start() | |||
self.console.show_cursor(show=True) | |||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
f_rich_progress = FRichProgress().new_progess( | |||
"[progress.description]{task.description}", | |||
@@ -12,32 +12,27 @@ def test_get_monitor_value(): | |||
with Capturing() as output: | |||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | |||
assert monitor == 'f1' and value==0.2 | |||
assert 'We can not find' not in output[0] | |||
# 测试可以匹配,且选择更靠前的 | |||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | |||
with Capturing() as output: | |||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | |||
assert monitor=='acc#f1' and value==0.2 | |||
assert 'We can not find' in output[0] | |||
# 测试monitor匹配不上,使用real_monitor | |||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | |||
with Capturing() as output: | |||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) | |||
monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) | |||
assert monitor=='acc#rec' and value==0.3 | |||
assert 'We can not find' not in output[0] | |||
# 测试monitor/real_monitor匹配不上, 重新选择 | |||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | |||
with Capturing() as output: | |||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) | |||
assert monitor=='acc#f1' and value==0.2 | |||
assert 'We can not find' in output[0] | |||
# 测试partial的位置 | |||
res = {"acc#acc": 0.52, "loss#loss": 2} | |||
with Capturing() as output: | |||
monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) | |||
assert monitor=='loss#loss' and value==2 | |||
assert 'We can not find' in output[0] |
@@ -7,38 +7,10 @@ import numpy as np | |||
# print(isinstance((1,), tuple)) | |||
# exit() | |||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object | |||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||
from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | |||
def test_convert_to_tensors(): | |||
local_rank = 0 | |||
obj = { | |||
'tensor': torch.full(size=(2,), fill_value=local_rank), | |||
'numpy': np.full(shape=(1,), fill_value=local_rank), | |||
'bool': local_rank % 2 == 0, | |||
'float': local_rank + 0.1, | |||
'int': local_rank, | |||
'dict': { | |||
'rank': local_rank | |||
}, | |||
'list': [local_rank] * 2, | |||
'str': 'xxx' | |||
} | |||
data = convert_to_tensors(obj) | |||
assert len(data) == len(obj) | |||
assert (data['tensor'] == obj['tensor']).sum() == 2 | |||
for name in ['list', 'str']: | |||
assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ | |||
isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 | |||
for name in ['numpy', 'bool', 'float', 'int']: | |||
assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 | |||
assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 | |||
@magic_argv_env_context | |||
def test_fastnlp_torch_all_gather(): | |||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||
@@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather(): | |||
'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), | |||
torch.full(size=(2,), fill_value=local_rank).cuda()] | |||
} | |||
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||
data = fastnlp_torch_all_gather(obj) | |||
world_size = int(os.environ['WORLD_SIZE']) | |||
assert len(data) == world_size | |||
for i in range(world_size): | |||
@@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather(): | |||
assert data[i]['tensors'][0][0] == i | |||
for obj in [1, True, 'xxx']: | |||
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||
data = fastnlp_torch_all_gather(obj) | |||
assert len(data)==world_size | |||
assert data[0]==data[1] | |||
dist.destroy_process_group() | |||
@magic_argv_env_context | |||
def test_fastnlp_torch_broadcast_object(): | |||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||
@@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object(): | |||
for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: | |||
data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) | |||
assert int(data)==0 | |||
dist.destroy_process_group() |
@@ -0,0 +1,88 @@ | |||
import pytest | |||
import torch | |||
import numpy as np | |||
from fastNLP.core.metrics import ClassifyFPreRecMetric | |||
class TestClassfiyFPreRecMetric: | |||
def test_case_1(self): | |||
pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910], | |||
[1.3410, 0.2889, -0.8667, -1.8580, 0.3029], | |||
[0.7459, -1.1957, 0.3231, 0.0308, -0.1847], | |||
[1.1439, -0.0057, 0.8203, 0.0312, -1.0051], | |||
[-0.4870, 0.3215, -0.8290, 0.9221, 0.4683], | |||
[0.9078, 1.0674, -0.5629, 0.3895, 0.8917], | |||
[-0.7743, -0.4041, -0.9026, 0.2112, 1.0892], | |||
[1.8232, -1.4188, -2.5615, -2.4187, 0.5907], | |||
[-1.0592, 0.4164, -0.1192, 1.4238, -0.9258], | |||
[-1.1137, 0.5773, 2.5778, 0.5398, -0.3323], | |||
[-0.3868, -0.5165, 0.2286, -1.3876, 0.5561], | |||
[-0.3304, 1.3619, -1.5744, 0.4902, -0.7661], | |||
[1.8387, 0.5234, 0.4269, 1.3748, -1.2793], | |||
[0.6692, 0.2571, 1.2425, -0.5894, -0.0184], | |||
[0.4165, 0.4084, -0.1280, 1.4489, -2.3058], | |||
[-0.5826, -0.5469, 1.5898, -0.2786, -0.9882], | |||
[-1.5548, -2.2891, 0.2983, -1.2145, -0.1947], | |||
[-0.7222, 2.3543, -0.5801, -0.0640, -1.5614], | |||
[-1.4978, 1.9297, -1.3652, -0.2358, 2.5566], | |||
[0.1561, -0.0316, 0.9331, 1.0363, 2.3949], | |||
[0.2650, -0.8459, 1.3221, 0.1321, -1.1900], | |||
[0.0664, -1.2353, -0.5242, -1.4491, 1.3300], | |||
[-0.2744, 0.0941, 0.7157, 0.1404, 1.2046], | |||
[0.9341, -0.6652, 1.4512, 0.9608, -0.3623], | |||
[-1.1641, 0.0873, 0.1163, -0.2068, -0.7002], | |||
[1.4775, -2.0025, -0.5634, -0.1589, 0.0247], | |||
[1.0151, 1.0304, -0.1042, -0.6955, -0.0629], | |||
[-0.3119, -0.4558, 0.7757, 0.0758, -1.6297], | |||
[1.0654, 0.0313, -0.7716, 0.1194, 0.6913], | |||
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207], | |||
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207], | |||
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]]) | |||
arg_max_pred = torch.argmax(pred, dim=-1) | |||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | |||
0, 3, 0, 0, 0, 1, 3, 1]) | |||
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
f1_score = 0.1882051282051282 | |||
recall = 0.1619047619047619 | |||
pre = 0.23928571428571427 | |||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} | |||
for keys in ['f', 'pre', 'rec']: | |||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
f1_score = 0.21875 | |||
recall = 0.21875 | |||
pre = 0.21875 | |||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} | |||
for keys in ['f', 'pre', 'rec']: | |||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
ground_truth = { | |||
'0': {'f1-score': 0.13333333333333333, 'precision': 0.125, 'recall': 0.14285714285714285, 'support': 7}, | |||
'1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5}, | |||
'2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2}, | |||
'3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9}, | |||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}, | |||
'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427, | |||
'recall': 0.1619047619047619, 'support': 32}, | |||
'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32}, | |||
'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875, | |||
'support': 32}} | |||
for keys in result_dict.keys(): | |||
if keys == "f" or "pre" or "rec": | |||
continue | |||
gl = str(keys[-1]) | |||
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"} | |||
gk = tmp_d[keys[0]] | |||
np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001) |