@@ -138,10 +138,6 @@ class CheckpointCallback(HasMonitorCallback): | |||
f'exception_{exception.__class__.__name__}' | |||
self.save(trainer=trainer, folder_name=folder_name) | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
# 主要核对一下 monitor 是否存在。 | |||
self.get_monitor_value(results=sanity_check_res) | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
""" | |||
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | |||
@@ -49,7 +49,8 @@ class HasMonitorCallback(Callback): | |||
self.monitor = monitor | |||
else: | |||
self.monitor = str(monitor) if monitor is not None else None | |||
self.larger_better = bool(larger_better) | |||
if self.monitor is not None: | |||
self.larger_better = bool(larger_better) | |||
if larger_better: | |||
self.monitor_value = float('-inf') | |||
else: | |||
@@ -71,6 +72,12 @@ class HasMonitorCallback(Callback): | |||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||
f"You can set it in the initialization or through Trainer.") | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
# 主要核对一下 monitor 是否存在。 | |||
if self.monitor is not None: | |||
self.get_monitor_value(results=sanity_check_res) | |||
def get_monitor_value(self, results:Dict)->Union[float, None]: | |||
""" | |||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | |||
@@ -10,7 +10,7 @@ import shutil | |||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import all_rank_call | |||
from fastNLP.envs import all_rank_call_context | |||
class LoadBestModelCallback(HasMonitorCallback): | |||
@@ -76,9 +76,6 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
super().on_after_trainer_initialized(trainer, driver) | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
self.get_monitor_value(sanity_check_res) | |||
def on_validate_end(self, trainer, results): | |||
if self.is_better_results(results, keep_if_better=True): | |||
if self.real_save_folder: | |||
@@ -86,7 +83,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
model_save_fn=self.model_save_fn) | |||
else: | |||
self.buffer.seek(0) | |||
with all_rank_call(): | |||
with all_rank_call_context(): | |||
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||
def on_train_end(self, trainer): | |||
@@ -11,14 +11,15 @@ class LRSchedCallback(Callback): | |||
根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | |||
:param scheduler: 实现了 step() 函数的对象 | |||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数 | |||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 | |||
之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 | |||
""" | |||
assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | |||
"step function." | |||
self.scheduler = scheduler | |||
self.step_on = 0 if step_on == 'batch' else 1 | |||
def on_train_batch_end(self, trainer): | |||
def on_before_optimizers_step(self, trainer, optimizers): | |||
if self.step_on == 0: | |||
self.scheduler.step() | |||
@@ -32,10 +32,6 @@ class ProgressCallback(HasMonitorCallback): | |||
def on_train_end(self, trainer): | |||
f_rich_progress.stop() | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | |||
self.get_monitor_value(sanity_check_res) | |||
class RichCallback(ProgressCallback): | |||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||
@@ -3,7 +3,6 @@ from functools import partial | |||
from dataclasses import is_dataclass | |||
import sys | |||
__all__ = [ | |||
'Evaluator' | |||
] | |||
@@ -75,8 +74,8 @@ class Evaluator: | |||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | |||
不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, | |||
metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 | |||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。如果为True,将使得每个进程上 | |||
的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。请确保使用的 metrics 支持自动分布式累积。 | |||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||
@@ -86,7 +85,8 @@ class Evaluator: | |||
self.model = model | |||
self.metrics = metrics | |||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | |||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, | |||
**kwargs) | |||
if dataloaders is None: | |||
raise ValueError("Parameter `dataloaders` can not be None.") | |||
@@ -105,9 +105,13 @@ class Evaluator: | |||
dataloaders = {None: dataloaders} | |||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | |||
self.driver.setup() | |||
self.driver.barrier() | |||
self.separator = kwargs.get('separator', '#') | |||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | |||
use_dist_sampler = kwargs.get("use_dist_sampler", False) # 如果是 Evaluator 自身的默认值的话,应当为 False; | |||
use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed()) | |||
if use_dist_sampler: | |||
self._dist_sampler = "unrepeatdist" | |||
else: | |||
@@ -115,8 +119,9 @@ class Evaluator: | |||
self._metric_wrapper = None | |||
_ = self.metrics_wrapper # 触发检查 | |||
self.driver.setup() | |||
self.driver.barrier() | |||
if self._dist_sampler is not None and not self.driver.is_distributed(): | |||
logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " | |||
"different process evaluating on different data.") | |||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | |||
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | |||
@@ -183,7 +188,7 @@ class Evaluator: | |||
return metric_results | |||
def start_progress_bar(self, total:int, dataloader_name): | |||
def start_progress_bar(self, total: int, dataloader_name): | |||
if self.progress_bar == 'rich': | |||
if dataloader_name is None: | |||
desc = f'Eval. Batch:0' | |||
@@ -208,7 +213,7 @@ class Evaluator: | |||
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | |||
visible=kwargs.get('visible', True)) | |||
elif self.progress_bar == 'raw': | |||
if self.verbose>1: | |||
if self.verbose > 1: | |||
logger.info(desc) | |||
def remove_progress_bar(self, dataloader_name): | |||
@@ -256,7 +261,7 @@ class Evaluator: | |||
""" | |||
self.metrics_wrapper.update(*args, **kwargs) | |||
def get_dataloader_metric(self, dataloader_name:Optional[str]='') -> Dict: | |||
def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: | |||
""" | |||
获取当前dataloader的metric结果 | |||
@@ -313,6 +318,7 @@ class _MetricsWrapper: | |||
并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。 | |||
""" | |||
def __init__(self, metrics, evaluator): | |||
self.evaluator = evaluator | |||
self._metrics = [] | |||
@@ -326,13 +332,14 @@ class _MetricsWrapper: | |||
# torchmetrics 是默认自动开启了多卡的 | |||
evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) | |||
elif isinstance(metric, Metric): | |||
if evaluator._dist_sampler is not None and evaluator.driver.is_distributed() \ | |||
and metric.aggregate_when_get_metric is False: | |||
logger.warning("You have replace the sampler as distributed sampler when evaluation, but your " | |||
f"metric:{metric_name}' `aggregate_when_get_metric` is False.") | |||
if evaluator._dist_sampler is None and evaluator.driver.is_distributed() \ | |||
and metric.aggregate_when_get_metric is True: | |||
pass # 这种情况无所谓,因为 | |||
# 如果数据是分布式的,但是不aggregate的话可能有问题 | |||
if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False: | |||
logger.warning_once( | |||
"You have replace the sampler as distributed sampler when evaluation, but your " | |||
f"metric {metric_name}:{metric.__class__.__name__}' `aggregate_when_get_metric` is False.") | |||
if metric.aggregate_when_get_metric is None: | |||
metric.aggregate_when_get_metric = evaluator._dist_sampler is not None | |||
metric.to(evaluator.driver.data_device) | |||
self._metric_names.append(metric_name) | |||
self._metrics.append(metric) | |||
@@ -343,8 +350,9 @@ class _MetricsWrapper: | |||
for metric in self._metrics: | |||
args = [] | |||
if not isinstance(batch, dict): | |||
logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||
f"the output of model to update metric.") | |||
logger.warning_once( | |||
f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||
f"the output of model to update metric.") | |||
else: | |||
args.append(batch) | |||
if not isinstance(outputs, dict): | |||
@@ -368,7 +376,7 @@ class _MetricsWrapper: | |||
elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric): | |||
metric.reset() | |||
def get_metric(self, dataloader_name:str, separator:str) -> Dict: | |||
def get_metric(self, dataloader_name: str, separator: str) -> Dict: | |||
""" | |||
将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是 | |||
indicator_name{separator}metric_name{separator}dataloader_name | |||
@@ -419,4 +427,4 @@ def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indic | |||
names.append(dataloader_name) | |||
if len(names) == 0: | |||
raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.") | |||
return separator.join(names) | |||
return separator.join(names) |
@@ -122,7 +122,8 @@ class Trainer(TrainerEventTrigger): | |||
注意如果 model_device 为 None,那么 data_device 不会起作用; | |||
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数; | |||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||
use_dist_sampler: 表示在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | |||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | |||
use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
@@ -211,12 +212,6 @@ class Trainer(TrainerEventTrigger): | |||
total_batches=None | |||
) | |||
use_dist_sampler = kwargs.get("use_dist_sampler", True) | |||
if use_dist_sampler: | |||
_dist_sampler = "dist" | |||
else: | |||
_dist_sampler = None | |||
""" 设置内部的 Evaluator """ | |||
if metrics is None and evaluate_dataloaders is not None: | |||
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") | |||
@@ -224,6 +219,18 @@ class Trainer(TrainerEventTrigger): | |||
if metrics is not None and evaluate_dataloaders is None: | |||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") | |||
self.metrics = metrics | |||
self.validate_every = evaluate_every | |||
self.driver.setup() | |||
self.driver.barrier() | |||
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) | |||
if use_dist_sampler: | |||
_dist_sampler = "dist" | |||
else: | |||
_dist_sampler = None | |||
self.evaluator = None | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
@@ -241,16 +248,10 @@ class Trainer(TrainerEventTrigger): | |||
output_mapping=output_mapping, | |||
fp16=fp16, | |||
verbose=0, | |||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", use_dist_sampler), | |||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), | |||
progress_bar=kwargs.get('progress_bar', 'auto') | |||
) | |||
self.metrics = metrics | |||
self.validate_every = evaluate_every | |||
self.driver.setup() | |||
self.driver.barrier() | |||
if train_fn is not None and not isinstance(train_fn, str): | |||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | |||
@@ -753,7 +754,7 @@ class Trainer(TrainerEventTrigger): | |||
""" | |||
if (self.global_forward_batches + 1) % self.accumulation_steps != 0: | |||
_no_sync_context = self.driver.get_no_sync_context() | |||
_no_sync_context = self.driver.get_model_no_sync_context() | |||
else: | |||
_no_sync_context = nullcontext | |||
@@ -199,9 +199,10 @@ class Driver(ABC): | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") | |||
def get_no_sync_context(self): | |||
def get_model_no_sync_context(self): | |||
r""" | |||
返回一个用于关闭多进程之间互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数,单卡的 driver 不需要; | |||
返回一个用于关闭多进程之间 model 中的自动互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数, | |||
单卡的 driver 不需要; | |||
:return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象; | |||
""" | |||
@@ -357,6 +358,8 @@ class Driver(ABC): | |||
r""" | |||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | |||
仅在多分布式训练场景中有使用。 | |||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 | |||
""" | |||
def is_distributed(self) -> bool: | |||
@@ -82,7 +82,7 @@ class JittorMPIDriver(JittorDriver): | |||
def is_global_zero(self): | |||
return self.global_rank == 0 | |||
def get_no_sync_context(self): | |||
def get_model_no_sync_context(self): | |||
return self.model.no_sync | |||
def unwrap_model(self): | |||
@@ -405,7 +405,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
def is_global_zero(self): | |||
return self.global_rank == 0 | |||
def get_no_sync_context(self): | |||
def get_model_no_sync_context(self): | |||
return self.model.no_sync | |||
def unwrap_model(self): | |||
@@ -5,7 +5,6 @@ import socket | |||
import numpy as np | |||
from time import sleep | |||
from typing import List, Optional, Union, Dict, Tuple, Callable | |||
from functools import partial | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
@@ -29,7 +28,7 @@ from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||
@@ -511,7 +510,7 @@ class TorchDDPDriver(TorchDriver): | |||
def is_global_zero(self): | |||
return self.global_rank == 0 | |||
def get_no_sync_context(self): | |||
def get_model_no_sync_context(self): | |||
# 注意此时的 model 是 "DistributedDataParallel" 对象; | |||
return self.model.no_sync | |||
@@ -526,7 +525,8 @@ class TorchDDPDriver(TorchDriver): | |||
return self.local_rank | |||
def barrier(self): | |||
torch.distributed.barrier(async_op=True) | |||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | |||
torch.distributed.barrier(async_op=True) | |||
def is_distributed(self): | |||
return True | |||
@@ -544,6 +544,8 @@ class TorchDDPDriver(TorchDriver): | |||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 | |||
return | |||
return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | |||
def all_gather(self, obj, group) -> List: | |||
@@ -569,6 +571,8 @@ class TorchDDPDriver(TorchDriver): | |||
:param group: | |||
:return: | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 | |||
return [obj] | |||
return fastnlp_torch_all_gather(obj, group=group) | |||
@@ -1,5 +1,6 @@ | |||
import io | |||
import pickle | |||
import os | |||
_pickler = pickle.Pickler | |||
_unpickler = pickle.Unpickler | |||
from typing import Any, List | |||
@@ -7,6 +8,7 @@ from typing import Any, List | |||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.envs.env import FASTNLP_NO_SYNC | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch import distributed as dist | |||
@@ -34,47 +36,15 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
) | |||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||
def fastnlp_torch_gather_object(obj, dst=0, group=DEFAULT_TORCH_GROUP): | |||
""" | |||
从其它 rank gather 东西到 dst rank 。 | |||
Gathers picklable objects from the whole group in a single process. | |||
Similar to :func:`gather`, but Python objects can be passed in. Note that the | |||
object must be picklable in order to be gathered. | |||
Args: | |||
obj (Any): Input object. Must be picklable. | |||
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it | |||
should be correctly sized as the size of the group for this | |||
collective and will contain the output. Must be ``None`` on non-dst | |||
ranks. (default is ``None``) | |||
dst (int, optional): Destination rank. (default is 0) | |||
group: (ProcessGroup, optional): The process group to work on. If None, | |||
the default process group will be used. Default is ``None``. | |||
Returns: | |||
None. On the ``dst`` rank, ``object_gather_list`` will contain the | |||
output of the collective. | |||
.. note:: Note that this API differs slightly from the gather collective | |||
since it does not provide an async_op handle and thus will be a blocking | |||
call. | |||
.. note:: Note that this API is not supported when using the NCCL backend. | |||
.. warning:: | |||
:func:`gather_object` uses ``pickle`` module implicitly, which is | |||
known to be insecure. It is possible to construct malicious pickle data | |||
which will execute arbitrary code during unpickling. Only call this | |||
function with data you trust. | |||
Example:: | |||
>>> # Note: Process group initialization omitted on each rank. | |||
>>> import torch.distributed as dist | |||
>>> # Assumes world_size of 3. | |||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||
>>> output = [None for _ in gather_objects] | |||
>>> dist.gather_object( | |||
>>> fastnlp_torch_gather_object( | |||
gather_objects[dist.get_rank()], | |||
output if dist.get_rank() == 0 else None, | |||
dst=0 | |||
@@ -82,7 +52,20 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAU | |||
>>> # On rank 0 | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
:param obj: 需要发送的 obj 对象,需要是可以 pickable 的对象 | |||
:param dst: 目标的 rank 。 | |||
:param group: 在哪个 group 执行该函数。 | |||
:return: 在 dst 上面返回 world_size 的 list,依次为 rank 0;rank 1...上 obj | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||
return [obj] | |||
if dist.get_rank() == dst: | |||
object_gather_list = [None for _ in range(dist.get_world_size(group))] | |||
else: | |||
object_gather_list = None | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
@@ -212,6 +195,9 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) - | |||
:param group: | |||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||
return [obj] | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
if isinstance(obj, torch.Tensor): | |||
@@ -232,12 +218,18 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR | |||
""" | |||
将 src 上的 obj 对象广播到其它 rank 上。 | |||
:param obj: | |||
:param src: | |||
:param obj: 需要发送的对象 | |||
:param src: 从哪里发出。 | |||
:param device: | |||
:param group: | |||
:param group: 属于哪个通信 group | |||
:return: | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||
if src == dist.get_rank(group): | |||
return obj | |||
else: | |||
return None | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
cur_rank = dist.get_rank(group) | |||
@@ -289,50 +281,23 @@ def all_gather_object(object_list, obj, group=None): | |||
""" | |||
复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 | |||
Gathers picklable objects from the whole group into a list. Similar to | |||
:func:`all_gather`, but Python objects can be passed in. Note that the object | |||
must be picklable in order to be gathered. | |||
Args: | |||
object_list (list[Any]): Output list. It should be correctly sized as the | |||
size of the group for this collective and will contain the output. | |||
object (Any): Pickable Python object to be broadcast from current process. | |||
group (ProcessGroup, optional): The process group to work on. If None, | |||
the default process group will be used. Default is ``None``. | |||
Returns: | |||
None. If the calling rank is part of this group, the output of the | |||
collective will be populated into the input ``object_list``. If the | |||
calling rank is not part of the group, the passed in ``object_list`` will | |||
be unmodified. | |||
.. note:: Note that this API differs slightly from the :func:`all_gather` | |||
collective since it does not provide an ``async_op`` handle and thus | |||
will be a blocking call. | |||
.. note:: For NCCL-based processed groups, internal tensor representations | |||
of objects must be moved to the GPU device before communication takes | |||
place. In this case, the device used is given by | |||
``torch.cuda.current_device()`` and it is the user's responsiblity to | |||
ensure that this is set so that each rank has an individual GPU, via | |||
``torch.cuda.set_device()``. | |||
.. warning:: | |||
:func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||
known to be insecure. It is possible to construct malicious pickle data | |||
which will execute arbitrary code during unpickling. Only call this | |||
function with data you trust. | |||
Example:: | |||
>>> # Note: Process group initialization omitted on each rank. | |||
>>> import torch.distributed as dist | |||
>>> # Assumes world_size of 3. | |||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||
>>> output = [None for _ in gather_objects] | |||
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||
>>> all_gather_object(output, gather_objects[dist.get_rank()]) | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
:param object_list: | |||
:param obj: | |||
:param group: | |||
:return: | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||
return [obj] | |||
if dist.distributed_c10d._rank_not_in_group(group): | |||
return | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
@@ -35,7 +35,6 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||
"'jittor', 'paddle', 'fleet'].") | |||
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | |||
""" | |||
使用 command 通过 subprocess.Popen 开启新的进程。 | |||
@@ -60,30 +59,3 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy: | |||
err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w') | |||
proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f) | |||
return proc | |||
def load_model(filepath: Union[str, Path], backend: str = "torch", **kwargs): | |||
r""" | |||
对应 `load_model`,用来帮助用户加载之前通过 `load_model` 所保存的模型; | |||
:param filepath: 加载的文件的位置; | |||
:param backend: 使用哪种 backend 来加载该 filepath, 目前支持 ["torch", "paddle", "jittor"] 。 | |||
""" | |||
if filepath is None: | |||
raise ValueError("Parameter `path` can not be None.") | |||
assert backend is not None, "Parameter `backend` can not be None." | |||
if backend == "torch": | |||
import torch | |||
_res = torch.load(filepath) | |||
return _res | |||
elif backend == "jittor": | |||
raise NotImplementedError | |||
elif backend == "paddle": | |||
raise NotImplementedError | |||
else: | |||
raise ValueError("Parameter `backend` could only be one of these values: ['torch', 'jittor', 'paddle']") | |||
@@ -13,15 +13,22 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||
class Accuracy(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | |||
""" | |||
计算 准确率 的 metric 。 | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||
""" | |||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | |||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | |||
def get_metric(self) -> dict: | |||
r""" | |||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||
get_metric 函数将根据 evaluate 函数累计的评价指标统计量来计算最终的评价结果. | |||
:return dict evaluate_result: {"acc": float} | |||
""" | |||
@@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.distributed as dist | |||
import torch.nn.functional as F | |||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | |||
@@ -31,7 +31,20 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
class ClassifyFPreRecMetric(Metric): | |||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | |||
aggregate_when_get_metric: bool = False) -> None: | |||
aggregate_when_get_metric: bool = None) -> None: | |||
""" | |||
:param tag_vocab: | |||
:param ignore_labels: | |||
:param num_class: | |||
:param only_gross: | |||
:param f_type: | |||
:param beta: | |||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | |||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||
""" | |||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | |||
aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
@@ -35,6 +35,8 @@ class Element: | |||
""" | |||
self._check_value_initialized() | |||
if self.aggregate_method is None: # 如果没有 aggregate 则不进行聚合。 | |||
return | |||
try: | |||
self._value = self.backend.aggregate(self._value, self.aggregate_method) | |||
except AggregateMethodError as e: | |||
@@ -14,13 +14,13 @@ from fastNLP.core.metrics.element import Element | |||
class Metric: | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | |||
""" | |||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | |||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。 | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||
""" | |||
self.backend = AutoBackend(backend) | |||
self._updated = False | |||
@@ -43,7 +43,7 @@ class Metric: | |||
:param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | |||
:param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 | |||
:param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。 | |||
:param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。如果设置为 None 则表示该 element 不进行聚合。 | |||
:param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 | |||
Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | |||
一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | |||
@@ -218,7 +218,7 @@ class SpanFPreRecMetric(Metric): | |||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||
only_gross: bool = True, f_type='micro', | |||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | |||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None: | |||
r""" | |||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
@@ -234,7 +234,7 @@ class SpanFPreRecMetric(Metric): | |||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。 | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||
""" | |||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
@@ -222,14 +222,14 @@ class RandomSampler(ReproducibleSampler): | |||
class SequentialSampler(RandomSampler): | |||
def __init__(self, dataset, dist_mode:str='interval', **kwargs): | |||
def __init__(self, dataset, **kwargs): | |||
""" | |||
按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||
:param kwargs: | |||
""" | |||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
super().__init__(dataset=dataset, **kwargs) | |||
def __iter__(self): | |||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||
@@ -6,8 +6,9 @@ __all__ = [ | |||
'is_cur_env_distributed', | |||
'get_global_rank', | |||
'rank_zero_call', | |||
'all_rank_call', | |||
'get_gpu_count' | |||
'all_rank_call_context', | |||
'get_gpu_count', | |||
'fastnlp_no_sync_context' | |||
] | |||
@@ -7,10 +7,11 @@ __all__ = [ | |||
'is_cur_env_distributed', | |||
'get_global_rank', | |||
'rank_zero_call', | |||
'all_rank_call' | |||
'all_rank_call_context', | |||
'fastnlp_no_sync_context' | |||
] | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC | |||
def is_cur_env_distributed() -> bool: | |||
@@ -41,24 +42,46 @@ def rank_zero_call(fn: Callable): | |||
return a+b | |||
rank_zero_call(add)(1, 2) | |||
同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何 | |||
意义。 | |||
:param fn: 需要包裹的可执行的函数。 | |||
:return: | |||
""" | |||
@wraps(fn) | |||
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
return fn(*args, **kwargs) | |||
with fastnlp_no_sync_context(level=2): | |||
return fn(*args, **kwargs) | |||
return None | |||
return wrapped_fn | |||
@contextmanager | |||
def all_rank_call(): | |||
def fastnlp_no_sync_context(level=2): | |||
""" | |||
用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; | |||
如果为 2 表示 barrier 与 gather/broadcast 都失效。 | |||
:param int level: 可选 [0, 1, 2] | |||
:return: | |||
""" | |||
old_level = os.environ.get(FASTNLP_NO_SYNC, None) | |||
os.environ[FASTNLP_NO_SYNC] = f'{level}' | |||
yield | |||
if old_level is None: | |||
os.environ.pop(FASTNLP_NO_SYNC) | |||
else: | |||
os.environ[FASTNLP_NO_SYNC] = old_level | |||
@contextmanager | |||
def all_rank_call_context(): | |||
""" | |||
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 | |||
# 使用方式 | |||
with all_rank_run(): | |||
with all_rank_call_context(): | |||
do_something # all rank will do | |||
:param fn: | |||
@@ -48,6 +48,10 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||
# fastNLP 中初始化deque的默认大小 | |||
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||
# fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | |||
# 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | |||
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | |||
# todo 注释 直接使用的变量 | |||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | |||
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | |||
@@ -69,7 +69,7 @@ class TestFleetDriverFunction: | |||
""" | |||
测试 get_no_sync_context 函数 | |||
""" | |||
res = self.driver.get_no_sync_context() | |||
res = self.driver.get_model_no_sync_context() | |||
dist.barrier() | |||
@magic_argv_env_context | |||
@@ -1,6 +1,6 @@ | |||
import os | |||
from fastNLP.envs.distributed import rank_zero_call, all_rank_call | |||
from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context | |||
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context | |||
@@ -70,7 +70,7 @@ class TestTorch: | |||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||
# torch.distributed.init_process_group(backend='nccl') | |||
# torch.distributed.barrier() | |||
with all_rank_call(): | |||
with all_rank_call_context(): | |||
with Capturing(no_del=True) as output: | |||
write_something() | |||
output = output[0] | |||
@@ -80,7 +80,7 @@ class TestTorch: | |||
else: | |||
assert '11111' in output | |||
with all_rank_call(): | |||
with all_rank_call_context(): | |||
with Capturing(no_del=True) as output: | |||
rank_zero_call(write_other_thing)() | |||