diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index b13632d1..2cb3510e 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -138,10 +138,6 @@ class CheckpointCallback(HasMonitorCallback): f'exception_{exception.__class__.__name__}' self.save(trainer=trainer, folder_name=folder_name) - def on_sanity_check_end(self, trainer, sanity_check_res): - # 主要核对一下 monitor 是否存在。 - self.get_monitor_value(results=sanity_check_res) - def on_save_checkpoint(self, trainer) -> Dict: """ 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index 54bd9bb4..74fa3aaf 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -49,7 +49,8 @@ class HasMonitorCallback(Callback): self.monitor = monitor else: self.monitor = str(monitor) if monitor is not None else None - self.larger_better = bool(larger_better) + if self.monitor is not None: + self.larger_better = bool(larger_better) if larger_better: self.monitor_value = float('-inf') else: @@ -71,6 +72,12 @@ class HasMonitorCallback(Callback): raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " f"You can set it in the initialization or through Trainer.") + + def on_sanity_check_end(self, trainer, sanity_check_res): + # 主要核对一下 monitor 是否存在。 + if self.monitor is not None: + self.get_monitor_value(results=sanity_check_res) + def get_monitor_value(self, results:Dict)->Union[float, None]: """ 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index f240caa7..93b95667 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -10,7 +10,7 @@ import shutil from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH from fastNLP.core.log import logger -from fastNLP.envs import all_rank_call +from fastNLP.envs import all_rank_call_context class LoadBestModelCallback(HasMonitorCallback): @@ -76,9 +76,6 @@ class LoadBestModelCallback(HasMonitorCallback): super().on_after_trainer_initialized(trainer, driver) - def on_sanity_check_end(self, trainer, sanity_check_res): - self.get_monitor_value(sanity_check_res) - def on_validate_end(self, trainer, results): if self.is_better_results(results, keep_if_better=True): if self.real_save_folder: @@ -86,7 +83,7 @@ class LoadBestModelCallback(HasMonitorCallback): model_save_fn=self.model_save_fn) else: self.buffer.seek(0) - with all_rank_call(): + with all_rank_call_context(): trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) def on_train_end(self, trainer): diff --git a/fastNLP/core/callbacks/lr_scheduler_callback.py b/fastNLP/core/callbacks/lr_scheduler_callback.py index a0219778..ba496b5e 100644 --- a/fastNLP/core/callbacks/lr_scheduler_callback.py +++ b/fastNLP/core/callbacks/lr_scheduler_callback.py @@ -11,14 +11,15 @@ class LRSchedCallback(Callback): 根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 :param scheduler: 实现了 step() 函数的对象 - :param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数 + :param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 + 之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 """ assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ "step function." self.scheduler = scheduler self.step_on = 0 if step_on == 'batch' else 1 - def on_train_batch_end(self, trainer): + def on_before_optimizers_step(self, trainer, optimizers): if self.step_on == 0: self.scheduler.step() diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index bb638122..f3d5a435 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -32,10 +32,6 @@ class ProgressCallback(HasMonitorCallback): def on_train_end(self, trainer): f_rich_progress.stop() - def on_sanity_check_end(self, trainer, sanity_check_res): - if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: - self.get_monitor_value(sanity_check_res) - class RichCallback(ProgressCallback): def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 7394961a..38522c9b 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -3,7 +3,6 @@ from functools import partial from dataclasses import is_dataclass import sys - __all__ = [ 'Evaluator' ] @@ -75,8 +74,8 @@ class Evaluator: 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 - use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。如果为True,将使得每个进程上 - 的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。请确保使用的 metrics 支持自动分布式累积。 + use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 + 分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; @@ -86,7 +85,8 @@ class Evaluator: self.model = model self.metrics = metrics - self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) + self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, + **kwargs) if dataloaders is None: raise ValueError("Parameter `dataloaders` can not be None.") @@ -105,9 +105,13 @@ class Evaluator: dataloaders = {None: dataloaders} self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) + + self.driver.setup() + self.driver.barrier() + self.separator = kwargs.get('separator', '#') self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) - use_dist_sampler = kwargs.get("use_dist_sampler", False) # 如果是 Evaluator 自身的默认值的话,应当为 False; + use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed()) if use_dist_sampler: self._dist_sampler = "unrepeatdist" else: @@ -115,8 +119,9 @@ class Evaluator: self._metric_wrapper = None _ = self.metrics_wrapper # 触发检查 - self.driver.setup() - self.driver.barrier() + if self._dist_sampler is not None and not self.driver.is_distributed(): + logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " + "different process evaluating on different data.") if evaluate_fn is not None and not isinstance(evaluate_fn, str): raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") @@ -183,7 +188,7 @@ class Evaluator: return metric_results - def start_progress_bar(self, total:int, dataloader_name): + def start_progress_bar(self, total: int, dataloader_name): if self.progress_bar == 'rich': if dataloader_name is None: desc = f'Eval. Batch:0' @@ -208,7 +213,7 @@ class Evaluator: advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), visible=kwargs.get('visible', True)) elif self.progress_bar == 'raw': - if self.verbose>1: + if self.verbose > 1: logger.info(desc) def remove_progress_bar(self, dataloader_name): @@ -256,7 +261,7 @@ class Evaluator: """ self.metrics_wrapper.update(*args, **kwargs) - def get_dataloader_metric(self, dataloader_name:Optional[str]='') -> Dict: + def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: """ 获取当前dataloader的metric结果 @@ -313,6 +318,7 @@ class _MetricsWrapper: 并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。 """ + def __init__(self, metrics, evaluator): self.evaluator = evaluator self._metrics = [] @@ -326,13 +332,14 @@ class _MetricsWrapper: # torchmetrics 是默认自动开启了多卡的 evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) elif isinstance(metric, Metric): - if evaluator._dist_sampler is not None and evaluator.driver.is_distributed() \ - and metric.aggregate_when_get_metric is False: - logger.warning("You have replace the sampler as distributed sampler when evaluation, but your " - f"metric:{metric_name}' `aggregate_when_get_metric` is False.") - if evaluator._dist_sampler is None and evaluator.driver.is_distributed() \ - and metric.aggregate_when_get_metric is True: - pass # 这种情况无所谓,因为 + # 如果数据是分布式的,但是不aggregate的话可能有问题 + if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False: + logger.warning_once( + "You have replace the sampler as distributed sampler when evaluation, but your " + f"metric {metric_name}:{metric.__class__.__name__}' `aggregate_when_get_metric` is False.") + if metric.aggregate_when_get_metric is None: + metric.aggregate_when_get_metric = evaluator._dist_sampler is not None + metric.to(evaluator.driver.data_device) self._metric_names.append(metric_name) self._metrics.append(metric) @@ -343,8 +350,9 @@ class _MetricsWrapper: for metric in self._metrics: args = [] if not isinstance(batch, dict): - logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " - f"the output of model to update metric.") + logger.warning_once( + f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " + f"the output of model to update metric.") else: args.append(batch) if not isinstance(outputs, dict): @@ -368,7 +376,7 @@ class _MetricsWrapper: elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric): metric.reset() - def get_metric(self, dataloader_name:str, separator:str) -> Dict: + def get_metric(self, dataloader_name: str, separator: str) -> Dict: """ 将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是 indicator_name{separator}metric_name{separator}dataloader_name @@ -419,4 +427,4 @@ def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indic names.append(dataloader_name) if len(names) == 0: raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.") - return separator.join(names) \ No newline at end of file + return separator.join(names) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index f26d841c..40ec635d 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -122,7 +122,8 @@ class Trainer(TrainerEventTrigger): 注意如果 model_device 为 None,那么 data_device 不会起作用; torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数; set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; - use_dist_sampler: 表示在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; + use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch + 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 @@ -211,12 +212,6 @@ class Trainer(TrainerEventTrigger): total_batches=None ) - use_dist_sampler = kwargs.get("use_dist_sampler", True) - if use_dist_sampler: - _dist_sampler = "dist" - else: - _dist_sampler = None - """ 设置内部的 Evaluator """ if metrics is None and evaluate_dataloaders is not None: raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") @@ -224,6 +219,18 @@ class Trainer(TrainerEventTrigger): if metrics is not None and evaluate_dataloaders is None: raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") + self.metrics = metrics + self.validate_every = evaluate_every + + self.driver.setup() + self.driver.barrier() + + use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) + if use_dist_sampler: + _dist_sampler = "dist" + else: + _dist_sampler = None + self.evaluator = None self.monitor = monitor self.larger_better = larger_better @@ -241,16 +248,10 @@ class Trainer(TrainerEventTrigger): output_mapping=output_mapping, fp16=fp16, verbose=0, - use_dist_sampler=kwargs.get("use_eval_dist_sampler", use_dist_sampler), + use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), progress_bar=kwargs.get('progress_bar', 'auto') ) - self.metrics = metrics - self.validate_every = evaluate_every - - self.driver.setup() - self.driver.barrier() - if train_fn is not None and not isinstance(train_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) @@ -753,7 +754,7 @@ class Trainer(TrainerEventTrigger): """ if (self.global_forward_batches + 1) % self.accumulation_steps != 0: - _no_sync_context = self.driver.get_no_sync_context() + _no_sync_context = self.driver.get_model_no_sync_context() else: _no_sync_context = nullcontext diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 06547516..1a810865 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -199,9 +199,10 @@ class Driver(ABC): """ raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") - def get_no_sync_context(self): + def get_model_no_sync_context(self): r""" - 返回一个用于关闭多进程之间互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数,单卡的 driver 不需要; + 返回一个用于关闭多进程之间 model 中的自动互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数, + 单卡的 driver 不需要; :return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象; """ @@ -357,6 +358,8 @@ class Driver(ABC): r""" 用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; 仅在多分布式训练场景中有使用。 + + 注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 """ def is_distributed(self) -> bool: diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index 98ac44a0..bb52f67d 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -82,7 +82,7 @@ class JittorMPIDriver(JittorDriver): def is_global_zero(self): return self.global_rank == 0 - def get_no_sync_context(self): + def get_model_no_sync_context(self): return self.model.no_sync def unwrap_model(self): diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index a083e42c..7f6f5cc5 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -405,7 +405,7 @@ class PaddleFleetDriver(PaddleDriver): def is_global_zero(self): return self.global_rank == 0 - def get_no_sync_context(self): + def get_model_no_sync_context(self): return self.model.no_sync def unwrap_model(self): diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index a37525f4..d68b6a0d 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -5,7 +5,6 @@ import socket import numpy as np from time import sleep from typing import List, Optional, Union, Dict, Tuple, Callable -from functools import partial from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: @@ -29,7 +28,7 @@ from fastNLP.core.drivers.utils import distributed_open_proc from fastNLP.core.utils import auto_param_call, check_user_specific_params from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler -from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED +from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC from fastNLP.core.log import logger from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object @@ -511,7 +510,7 @@ class TorchDDPDriver(TorchDriver): def is_global_zero(self): return self.global_rank == 0 - def get_no_sync_context(self): + def get_model_no_sync_context(self): # 注意此时的 model 是 "DistributedDataParallel" 对象; return self.model.no_sync @@ -526,7 +525,8 @@ class TorchDDPDriver(TorchDriver): return self.local_rank def barrier(self): - torch.distributed.barrier(async_op=True) + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 + torch.distributed.barrier(async_op=True) def is_distributed(self): return True @@ -544,6 +544,8 @@ class TorchDDPDriver(TorchDriver): :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 """ + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 + return return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) def all_gather(self, obj, group) -> List: @@ -569,6 +571,8 @@ class TorchDDPDriver(TorchDriver): :param group: :return: """ + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 + return [obj] return fastnlp_torch_all_gather(obj, group=group) diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index 37110577..c77b8416 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -1,5 +1,6 @@ import io import pickle +import os _pickler = pickle.Pickler _unpickler = pickle.Unpickler from typing import Any, List @@ -7,6 +8,7 @@ from typing import Any, List from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.envs.env import FASTNLP_NO_SYNC if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist @@ -34,47 +36,15 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): ) -def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): +def fastnlp_torch_gather_object(obj, dst=0, group=DEFAULT_TORCH_GROUP): """ 从其它 rank gather 东西到 dst rank 。 - Gathers picklable objects from the whole group in a single process. - Similar to :func:`gather`, but Python objects can be passed in. Note that the - object must be picklable in order to be gathered. - - Args: - obj (Any): Input object. Must be picklable. - object_gather_list (list[Any]): Output list. On the ``dst`` rank, it - should be correctly sized as the size of the group for this - collective and will contain the output. Must be ``None`` on non-dst - ranks. (default is ``None``) - dst (int, optional): Destination rank. (default is 0) - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - - Returns: - None. On the ``dst`` rank, ``object_gather_list`` will contain the - output of the collective. - - .. note:: Note that this API differs slightly from the gather collective - since it does not provide an async_op handle and thus will be a blocking - call. - - .. note:: Note that this API is not supported when using the NCCL backend. - - .. warning:: - :func:`gather_object` uses ``pickle`` module implicitly, which is - known to be insecure. It is possible to construct malicious pickle data - which will execute arbitrary code during unpickling. Only call this - function with data you trust. - Example:: - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist >>> # Assumes world_size of 3. >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object >>> output = [None for _ in gather_objects] - >>> dist.gather_object( + >>> fastnlp_torch_gather_object( gather_objects[dist.get_rank()], output if dist.get_rank() == 0 else None, dst=0 @@ -82,7 +52,20 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAU >>> # On rank 0 >>> output ['foo', 12, {1: 2}] + + :param obj: 需要发送的 obj 对象,需要是可以 pickable 的对象 + :param dst: 目标的 rank 。 + :param group: 在哪个 group 执行该函数。 + :return: 在 dst 上面返回 world_size 的 list,依次为 rank 0;rank 1...上 obj """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + + if dist.get_rank() == dst: + object_gather_list = [None for _ in range(dist.get_world_size(group))] + else: + object_gather_list = None + if group is None: group = DEFAULT_TORCH_GROUP @@ -212,6 +195,9 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) - :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + if group is None: group = DEFAULT_TORCH_GROUP if isinstance(obj, torch.Tensor): @@ -232,12 +218,18 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR """ 将 src 上的 obj 对象广播到其它 rank 上。 - :param obj: - :param src: + :param obj: 需要发送的对象 + :param src: 从哪里发出。 :param device: - :param group: + :param group: 属于哪个通信 group :return: """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + if src == dist.get_rank(group): + return obj + else: + return None + if group is None: group = DEFAULT_TORCH_GROUP cur_rank = dist.get_rank(group) @@ -289,50 +281,23 @@ def all_gather_object(object_list, obj, group=None): """ 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 - Gathers picklable objects from the whole group into a list. Similar to - :func:`all_gather`, but Python objects can be passed in. Note that the object - must be picklable in order to be gathered. - - Args: - object_list (list[Any]): Output list. It should be correctly sized as the - size of the group for this collective and will contain the output. - object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - - Returns: - None. If the calling rank is part of this group, the output of the - collective will be populated into the input ``object_list``. If the - calling rank is not part of the group, the passed in ``object_list`` will - be unmodified. - - .. note:: Note that this API differs slightly from the :func:`all_gather` - collective since it does not provide an ``async_op`` handle and thus - will be a blocking call. - - .. note:: For NCCL-based processed groups, internal tensor representations - of objects must be moved to the GPU device before communication takes - place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to - ensure that this is set so that each rank has an individual GPU, via - ``torch.cuda.set_device()``. - - .. warning:: - :func:`all_gather_object` uses ``pickle`` module implicitly, which is - known to be insecure. It is possible to construct malicious pickle data - which will execute arbitrary code during unpickling. Only call this - function with data you trust. - Example:: >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist >>> # Assumes world_size of 3. >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object >>> output = [None for _ in gather_objects] - >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> all_gather_object(output, gather_objects[dist.get_rank()]) >>> output ['foo', 12, {1: 2}] + + :param object_list: + :param obj: + :param group: + :return: """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + if dist.distributed_c10d._rank_not_in_group(group): return if _TORCH_GREATER_EQUAL_1_8: diff --git a/fastNLP/core/drivers/utils.py b/fastNLP/core/drivers/utils.py index 358046b6..040747f0 100644 --- a/fastNLP/core/drivers/utils.py +++ b/fastNLP/core/drivers/utils.py @@ -35,7 +35,6 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, "'jittor', 'paddle', 'fleet'].") - def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): """ 使用 command 通过 subprocess.Popen 开启新的进程。 @@ -60,30 +59,3 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy: err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w') proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f) return proc - - -def load_model(filepath: Union[str, Path], backend: str = "torch", **kwargs): - r""" - 对应 `load_model`,用来帮助用户加载之前通过 `load_model` 所保存的模型; - - :param filepath: 加载的文件的位置; - :param backend: 使用哪种 backend 来加载该 filepath, 目前支持 ["torch", "paddle", "jittor"] 。 - """ - - if filepath is None: - raise ValueError("Parameter `path` can not be None.") - - assert backend is not None, "Parameter `backend` can not be None." - - if backend == "torch": - import torch - _res = torch.load(filepath) - return _res - elif backend == "jittor": - raise NotImplementedError - elif backend == "paddle": - raise NotImplementedError - else: - raise ValueError("Parameter `backend` could only be one of these values: ['torch', 'jittor', 'paddle']") - - diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index d1ac1776..8b3889d0 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -13,15 +13,22 @@ from fastNLP.core.utils.utils import seq_len_to_mask class Accuracy(Metric): + def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): + """ + 计算 准确率 的 metric 。 - def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): + :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() + 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 + :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, + 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 + """ super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) def get_metric(self) -> dict: r""" - get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. + get_metric 函数将根据 evaluate 函数累计的评价指标统计量来计算最终的评价结果. :return dict evaluate_result: {"acc": float} """ diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py index 8945ab01..45f25ba9 100644 --- a/fastNLP/core/metrics/backend/torch_backend/backend.py +++ b/fastNLP/core/metrics/backend/torch_backend/backend.py @@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe if _NEED_IMPORT_TORCH: import torch import torch.distributed as dist - import torch.nn.functional as F def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py index 6298eae2..c030d257 100644 --- a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py @@ -31,7 +31,20 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): class ClassifyFPreRecMetric(Metric): def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', - aggregate_when_get_metric: bool = False) -> None: + aggregate_when_get_metric: bool = None) -> None: + """ + + :param tag_vocab: + :param ignore_labels: + :param num_class: + :param only_gross: + :param f_type: + :param beta: + :param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() + 函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 + :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, + 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 + """ super(ClassifyFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) if f_type not in ('micro', 'macro'): diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py index 22ba2635..2e492e23 100644 --- a/fastNLP/core/metrics/element.py +++ b/fastNLP/core/metrics/element.py @@ -35,6 +35,8 @@ class Element: """ self._check_value_initialized() + if self.aggregate_method is None: # 如果没有 aggregate 则不进行聚合。 + return try: self._value = self.backend.aggregate(self._value, self.aggregate_method) except AggregateMethodError as e: diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 2fb575fc..ef4839df 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -14,13 +14,13 @@ from fastNLP.core.metrics.element import Element class Metric: - def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): + def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): """ :param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() 函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, - 当 backend 不支持分布式时,该参数无意义。 + 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 """ self.backend = AutoBackend(backend) self._updated = False @@ -43,7 +43,7 @@ class Metric: :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 :param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 - :param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。 + :param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。如果设置为 None 则表示该 element 不进行聚合。 :param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py index 716cea30..a49914a5 100644 --- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py @@ -218,7 +218,7 @@ class SpanFPreRecMetric(Metric): def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', - beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: + beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None: r""" :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), @@ -234,7 +234,7 @@ class SpanFPreRecMetric(Metric): :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, - 当 backend 不支持分布式时,该参数无意义。 + 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 """ super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) if f_type not in ('micro', 'macro'): diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 6ea9cc6b..43017098 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -222,14 +222,14 @@ class RandomSampler(ReproducibleSampler): class SequentialSampler(RandomSampler): - def __init__(self, dataset, dist_mode:str='interval', **kwargs): + def __init__(self, dataset, **kwargs): """ 按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 :param dataset: 实现了 __len__ 方法的数据容器。 :param kwargs: """ - super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) + super().__init__(dataset=dataset, **kwargs) def __iter__(self): if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 diff --git a/fastNLP/envs/__init__.py b/fastNLP/envs/__init__.py index 4ae30677..8ec8f3a4 100644 --- a/fastNLP/envs/__init__.py +++ b/fastNLP/envs/__init__.py @@ -6,8 +6,9 @@ __all__ = [ 'is_cur_env_distributed', 'get_global_rank', 'rank_zero_call', - 'all_rank_call', - 'get_gpu_count' + 'all_rank_call_context', + 'get_gpu_count', + 'fastnlp_no_sync_context' ] diff --git a/fastNLP/envs/distributed.py b/fastNLP/envs/distributed.py index f608272b..34515c2c 100644 --- a/fastNLP/envs/distributed.py +++ b/fastNLP/envs/distributed.py @@ -7,10 +7,11 @@ __all__ = [ 'is_cur_env_distributed', 'get_global_rank', 'rank_zero_call', - 'all_rank_call' + 'all_rank_call_context', + 'fastnlp_no_sync_context' ] -from fastNLP.envs.env import FASTNLP_GLOBAL_RANK +from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC def is_cur_env_distributed() -> bool: @@ -41,24 +42,46 @@ def rank_zero_call(fn: Callable): return a+b rank_zero_call(add)(1, 2) + 同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何 + 意义。 + :param fn: 需要包裹的可执行的函数。 :return: """ @wraps(fn) def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - return fn(*args, **kwargs) + with fastnlp_no_sync_context(level=2): + return fn(*args, **kwargs) return None return wrapped_fn @contextmanager -def all_rank_call(): +def fastnlp_no_sync_context(level=2): + """ + 用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; + 如果为 2 表示 barrier 与 gather/broadcast 都失效。 + + :param int level: 可选 [0, 1, 2] + :return: + """ + old_level = os.environ.get(FASTNLP_NO_SYNC, None) + os.environ[FASTNLP_NO_SYNC] = f'{level}' + yield + if old_level is None: + os.environ.pop(FASTNLP_NO_SYNC) + else: + os.environ[FASTNLP_NO_SYNC] = old_level + + +@contextmanager +def all_rank_call_context(): """ 在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 # 使用方式 - with all_rank_run(): + with all_rank_call_context(): do_something # all rank will do :param fn: diff --git a/fastNLP/envs/env.py b/fastNLP/envs/env.py index b0683259..a943de1f 100644 --- a/fastNLP/envs/env.py +++ b/fastNLP/envs/env.py @@ -48,6 +48,10 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" # fastNLP 中初始化deque的默认大小 FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' +# fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; +# 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 +FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' + # todo 注释 直接使用的变量 FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index de98f9c5..c775a3a2 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -69,7 +69,7 @@ class TestFleetDriverFunction: """ 测试 get_no_sync_context 函数 """ - res = self.driver.get_no_sync_context() + res = self.driver.get_model_no_sync_context() dist.barrier() @magic_argv_env_context diff --git a/tests/core/utils/test_distributed.py b/tests/core/utils/test_distributed.py index 017f412d..eecbc72a 100644 --- a/tests/core/utils/test_distributed.py +++ b/tests/core/utils/test_distributed.py @@ -1,6 +1,6 @@ import os -from fastNLP.envs.distributed import rank_zero_call, all_rank_call +from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context @@ -70,7 +70,7 @@ class TestTorch: re_run_current_cmd_for_torch(1, output_from_new_proc='all') # torch.distributed.init_process_group(backend='nccl') # torch.distributed.barrier() - with all_rank_call(): + with all_rank_call_context(): with Capturing(no_del=True) as output: write_something() output = output[0] @@ -80,7 +80,7 @@ class TestTorch: else: assert '11111' in output - with all_rank_call(): + with all_rank_call_context(): with Capturing(no_del=True) as output: rank_zero_call(write_other_thing)()