@@ -138,10 +138,6 @@ class CheckpointCallback(HasMonitorCallback): | |||||
f'exception_{exception.__class__.__name__}' | f'exception_{exception.__class__.__name__}' | ||||
self.save(trainer=trainer, folder_name=folder_name) | self.save(trainer=trainer, folder_name=folder_name) | ||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
# 主要核对一下 monitor 是否存在。 | |||||
self.get_monitor_value(results=sanity_check_res) | |||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | """ | ||||
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | ||||
@@ -49,7 +49,8 @@ class HasMonitorCallback(Callback): | |||||
self.monitor = monitor | self.monitor = monitor | ||||
else: | else: | ||||
self.monitor = str(monitor) if monitor is not None else None | self.monitor = str(monitor) if monitor is not None else None | ||||
self.larger_better = bool(larger_better) | |||||
if self.monitor is not None: | |||||
self.larger_better = bool(larger_better) | |||||
if larger_better: | if larger_better: | ||||
self.monitor_value = float('-inf') | self.monitor_value = float('-inf') | ||||
else: | else: | ||||
@@ -71,6 +72,12 @@ class HasMonitorCallback(Callback): | |||||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | ||||
f"You can set it in the initialization or through Trainer.") | f"You can set it in the initialization or through Trainer.") | ||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
# 主要核对一下 monitor 是否存在。 | |||||
if self.monitor is not None: | |||||
self.get_monitor_value(results=sanity_check_res) | |||||
def get_monitor_value(self, results:Dict)->Union[float, None]: | def get_monitor_value(self, results:Dict)->Union[float, None]: | ||||
""" | """ | ||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | ||||
@@ -10,7 +10,7 @@ import shutil | |||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH | from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import all_rank_call | |||||
from fastNLP.envs import all_rank_call_context | |||||
class LoadBestModelCallback(HasMonitorCallback): | class LoadBestModelCallback(HasMonitorCallback): | ||||
@@ -76,9 +76,6 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
self.get_monitor_value(sanity_check_res) | |||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
if self.real_save_folder: | if self.real_save_folder: | ||||
@@ -86,7 +83,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
model_save_fn=self.model_save_fn) | model_save_fn=self.model_save_fn) | ||||
else: | else: | ||||
self.buffer.seek(0) | self.buffer.seek(0) | ||||
with all_rank_call(): | |||||
with all_rank_call_context(): | |||||
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
@@ -11,14 +11,15 @@ class LRSchedCallback(Callback): | |||||
根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | 根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | ||||
:param scheduler: 实现了 step() 函数的对象 | :param scheduler: 实现了 step() 函数的对象 | ||||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数 | |||||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 | |||||
之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 | |||||
""" | """ | ||||
assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | ||||
"step function." | "step function." | ||||
self.scheduler = scheduler | self.scheduler = scheduler | ||||
self.step_on = 0 if step_on == 'batch' else 1 | self.step_on = 0 if step_on == 'batch' else 1 | ||||
def on_train_batch_end(self, trainer): | |||||
def on_before_optimizers_step(self, trainer, optimizers): | |||||
if self.step_on == 0: | if self.step_on == 0: | ||||
self.scheduler.step() | self.scheduler.step() | ||||
@@ -32,10 +32,6 @@ class ProgressCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
f_rich_progress.stop() | f_rich_progress.stop() | ||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | |||||
self.get_monitor_value(sanity_check_res) | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
@@ -3,7 +3,6 @@ from functools import partial | |||||
from dataclasses import is_dataclass | from dataclasses import is_dataclass | ||||
import sys | import sys | ||||
__all__ = [ | __all__ = [ | ||||
'Evaluator' | 'Evaluator' | ||||
] | ] | ||||
@@ -75,8 +74,8 @@ class Evaluator: | |||||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | ||||
不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, | 不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, | ||||
metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 | metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 | ||||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。如果为True,将使得每个进程上 | |||||
的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。请确保使用的 metrics 支持自动分布式累积。 | |||||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
@@ -86,7 +85,8 @@ class Evaluator: | |||||
self.model = model | self.model = model | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | |||||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, | |||||
**kwargs) | |||||
if dataloaders is None: | if dataloaders is None: | ||||
raise ValueError("Parameter `dataloaders` can not be None.") | raise ValueError("Parameter `dataloaders` can not be None.") | ||||
@@ -105,9 +105,13 @@ class Evaluator: | |||||
dataloaders = {None: dataloaders} | dataloaders = {None: dataloaders} | ||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | ||||
self.driver.setup() | |||||
self.driver.barrier() | |||||
self.separator = kwargs.get('separator', '#') | self.separator = kwargs.get('separator', '#') | ||||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | ||||
use_dist_sampler = kwargs.get("use_dist_sampler", False) # 如果是 Evaluator 自身的默认值的话,应当为 False; | |||||
use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed()) | |||||
if use_dist_sampler: | if use_dist_sampler: | ||||
self._dist_sampler = "unrepeatdist" | self._dist_sampler = "unrepeatdist" | ||||
else: | else: | ||||
@@ -115,8 +119,9 @@ class Evaluator: | |||||
self._metric_wrapper = None | self._metric_wrapper = None | ||||
_ = self.metrics_wrapper # 触发检查 | _ = self.metrics_wrapper # 触发检查 | ||||
self.driver.setup() | |||||
self.driver.barrier() | |||||
if self._dist_sampler is not None and not self.driver.is_distributed(): | |||||
logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " | |||||
"different process evaluating on different data.") | |||||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | if evaluate_fn is not None and not isinstance(evaluate_fn, str): | ||||
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | ||||
@@ -183,7 +188,7 @@ class Evaluator: | |||||
return metric_results | return metric_results | ||||
def start_progress_bar(self, total:int, dataloader_name): | |||||
def start_progress_bar(self, total: int, dataloader_name): | |||||
if self.progress_bar == 'rich': | if self.progress_bar == 'rich': | ||||
if dataloader_name is None: | if dataloader_name is None: | ||||
desc = f'Eval. Batch:0' | desc = f'Eval. Batch:0' | ||||
@@ -208,7 +213,7 @@ class Evaluator: | |||||
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | ||||
visible=kwargs.get('visible', True)) | visible=kwargs.get('visible', True)) | ||||
elif self.progress_bar == 'raw': | elif self.progress_bar == 'raw': | ||||
if self.verbose>1: | |||||
if self.verbose > 1: | |||||
logger.info(desc) | logger.info(desc) | ||||
def remove_progress_bar(self, dataloader_name): | def remove_progress_bar(self, dataloader_name): | ||||
@@ -256,7 +261,7 @@ class Evaluator: | |||||
""" | """ | ||||
self.metrics_wrapper.update(*args, **kwargs) | self.metrics_wrapper.update(*args, **kwargs) | ||||
def get_dataloader_metric(self, dataloader_name:Optional[str]='') -> Dict: | |||||
def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: | |||||
""" | """ | ||||
获取当前dataloader的metric结果 | 获取当前dataloader的metric结果 | ||||
@@ -313,6 +318,7 @@ class _MetricsWrapper: | |||||
并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。 | 并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。 | ||||
""" | """ | ||||
def __init__(self, metrics, evaluator): | def __init__(self, metrics, evaluator): | ||||
self.evaluator = evaluator | self.evaluator = evaluator | ||||
self._metrics = [] | self._metrics = [] | ||||
@@ -326,13 +332,14 @@ class _MetricsWrapper: | |||||
# torchmetrics 是默认自动开启了多卡的 | # torchmetrics 是默认自动开启了多卡的 | ||||
evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) | evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) | ||||
elif isinstance(metric, Metric): | elif isinstance(metric, Metric): | ||||
if evaluator._dist_sampler is not None and evaluator.driver.is_distributed() \ | |||||
and metric.aggregate_when_get_metric is False: | |||||
logger.warning("You have replace the sampler as distributed sampler when evaluation, but your " | |||||
f"metric:{metric_name}' `aggregate_when_get_metric` is False.") | |||||
if evaluator._dist_sampler is None and evaluator.driver.is_distributed() \ | |||||
and metric.aggregate_when_get_metric is True: | |||||
pass # 这种情况无所谓,因为 | |||||
# 如果数据是分布式的,但是不aggregate的话可能有问题 | |||||
if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False: | |||||
logger.warning_once( | |||||
"You have replace the sampler as distributed sampler when evaluation, but your " | |||||
f"metric {metric_name}:{metric.__class__.__name__}' `aggregate_when_get_metric` is False.") | |||||
if metric.aggregate_when_get_metric is None: | |||||
metric.aggregate_when_get_metric = evaluator._dist_sampler is not None | |||||
metric.to(evaluator.driver.data_device) | metric.to(evaluator.driver.data_device) | ||||
self._metric_names.append(metric_name) | self._metric_names.append(metric_name) | ||||
self._metrics.append(metric) | self._metrics.append(metric) | ||||
@@ -343,8 +350,9 @@ class _MetricsWrapper: | |||||
for metric in self._metrics: | for metric in self._metrics: | ||||
args = [] | args = [] | ||||
if not isinstance(batch, dict): | if not isinstance(batch, dict): | ||||
logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||||
f"the output of model to update metric.") | |||||
logger.warning_once( | |||||
f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||||
f"the output of model to update metric.") | |||||
else: | else: | ||||
args.append(batch) | args.append(batch) | ||||
if not isinstance(outputs, dict): | if not isinstance(outputs, dict): | ||||
@@ -368,7 +376,7 @@ class _MetricsWrapper: | |||||
elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric): | elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric): | ||||
metric.reset() | metric.reset() | ||||
def get_metric(self, dataloader_name:str, separator:str) -> Dict: | |||||
def get_metric(self, dataloader_name: str, separator: str) -> Dict: | |||||
""" | """ | ||||
将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是 | 将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是 | ||||
indicator_name{separator}metric_name{separator}dataloader_name | indicator_name{separator}metric_name{separator}dataloader_name | ||||
@@ -419,4 +427,4 @@ def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indic | |||||
names.append(dataloader_name) | names.append(dataloader_name) | ||||
if len(names) == 0: | if len(names) == 0: | ||||
raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.") | raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.") | ||||
return separator.join(names) | |||||
return separator.join(names) |
@@ -122,7 +122,8 @@ class Trainer(TrainerEventTrigger): | |||||
注意如果 model_device 为 None,那么 data_device 不会起作用; | 注意如果 model_device 为 None,那么 data_device 不会起作用; | ||||
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数; | torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数; | ||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | ||||
use_dist_sampler: 表示在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | |||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | |||||
use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | ||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
@@ -211,12 +212,6 @@ class Trainer(TrainerEventTrigger): | |||||
total_batches=None | total_batches=None | ||||
) | ) | ||||
use_dist_sampler = kwargs.get("use_dist_sampler", True) | |||||
if use_dist_sampler: | |||||
_dist_sampler = "dist" | |||||
else: | |||||
_dist_sampler = None | |||||
""" 设置内部的 Evaluator """ | """ 设置内部的 Evaluator """ | ||||
if metrics is None and evaluate_dataloaders is not None: | if metrics is None and evaluate_dataloaders is not None: | ||||
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") | raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") | ||||
@@ -224,6 +219,18 @@ class Trainer(TrainerEventTrigger): | |||||
if metrics is not None and evaluate_dataloaders is None: | if metrics is not None and evaluate_dataloaders is None: | ||||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") | raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") | ||||
self.metrics = metrics | |||||
self.validate_every = evaluate_every | |||||
self.driver.setup() | |||||
self.driver.barrier() | |||||
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) | |||||
if use_dist_sampler: | |||||
_dist_sampler = "dist" | |||||
else: | |||||
_dist_sampler = None | |||||
self.evaluator = None | self.evaluator = None | ||||
self.monitor = monitor | self.monitor = monitor | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
@@ -241,16 +248,10 @@ class Trainer(TrainerEventTrigger): | |||||
output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
fp16=fp16, | fp16=fp16, | ||||
verbose=0, | verbose=0, | ||||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", use_dist_sampler), | |||||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), | |||||
progress_bar=kwargs.get('progress_bar', 'auto') | progress_bar=kwargs.get('progress_bar', 'auto') | ||||
) | ) | ||||
self.metrics = metrics | |||||
self.validate_every = evaluate_every | |||||
self.driver.setup() | |||||
self.driver.barrier() | |||||
if train_fn is not None and not isinstance(train_fn, str): | if train_fn is not None and not isinstance(train_fn, str): | ||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | ||||
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | ||||
@@ -753,7 +754,7 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
if (self.global_forward_batches + 1) % self.accumulation_steps != 0: | if (self.global_forward_batches + 1) % self.accumulation_steps != 0: | ||||
_no_sync_context = self.driver.get_no_sync_context() | |||||
_no_sync_context = self.driver.get_model_no_sync_context() | |||||
else: | else: | ||||
_no_sync_context = nullcontext | _no_sync_context = nullcontext | ||||
@@ -199,9 +199,10 @@ class Driver(ABC): | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") | raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") | ||||
def get_no_sync_context(self): | |||||
def get_model_no_sync_context(self): | |||||
r""" | r""" | ||||
返回一个用于关闭多进程之间互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数,单卡的 driver 不需要; | |||||
返回一个用于关闭多进程之间 model 中的自动互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数, | |||||
单卡的 driver 不需要; | |||||
:return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象; | :return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象; | ||||
""" | """ | ||||
@@ -357,6 +358,8 @@ class Driver(ABC): | |||||
r""" | r""" | ||||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | 用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | ||||
仅在多分布式训练场景中有使用。 | 仅在多分布式训练场景中有使用。 | ||||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 | |||||
""" | """ | ||||
def is_distributed(self) -> bool: | def is_distributed(self) -> bool: | ||||
@@ -82,7 +82,7 @@ class JittorMPIDriver(JittorDriver): | |||||
def is_global_zero(self): | def is_global_zero(self): | ||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
def get_no_sync_context(self): | |||||
def get_model_no_sync_context(self): | |||||
return self.model.no_sync | return self.model.no_sync | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
@@ -405,7 +405,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def is_global_zero(self): | def is_global_zero(self): | ||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
def get_no_sync_context(self): | |||||
def get_model_no_sync_context(self): | |||||
return self.model.no_sync | return self.model.no_sync | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
@@ -5,7 +5,6 @@ import socket | |||||
import numpy as np | import numpy as np | ||||
from time import sleep | from time import sleep | ||||
from typing import List, Optional, Union, Dict, Tuple, Callable | from typing import List, Optional, Union, Dict, Tuple, Callable | ||||
from functools import partial | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
@@ -29,7 +28,7 @@ from fastNLP.core.drivers.utils import distributed_open_proc | |||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | from fastNLP.core.utils import auto_param_call, check_user_specific_params | ||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | ||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | ||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | ||||
@@ -511,7 +510,7 @@ class TorchDDPDriver(TorchDriver): | |||||
def is_global_zero(self): | def is_global_zero(self): | ||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
def get_no_sync_context(self): | |||||
def get_model_no_sync_context(self): | |||||
# 注意此时的 model 是 "DistributedDataParallel" 对象; | # 注意此时的 model 是 "DistributedDataParallel" 对象; | ||||
return self.model.no_sync | return self.model.no_sync | ||||
@@ -526,7 +525,8 @@ class TorchDDPDriver(TorchDriver): | |||||
return self.local_rank | return self.local_rank | ||||
def barrier(self): | def barrier(self): | ||||
torch.distributed.barrier(async_op=True) | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | |||||
torch.distributed.barrier(async_op=True) | |||||
def is_distributed(self): | def is_distributed(self): | ||||
return True | return True | ||||
@@ -544,6 +544,8 @@ class TorchDDPDriver(TorchDriver): | |||||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | ||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | ||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 | |||||
return | |||||
return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | ||||
def all_gather(self, obj, group) -> List: | def all_gather(self, obj, group) -> List: | ||||
@@ -569,6 +571,8 @@ class TorchDDPDriver(TorchDriver): | |||||
:param group: | :param group: | ||||
:return: | :return: | ||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 | |||||
return [obj] | |||||
return fastnlp_torch_all_gather(obj, group=group) | return fastnlp_torch_all_gather(obj, group=group) | ||||
@@ -1,5 +1,6 @@ | |||||
import io | import io | ||||
import pickle | import pickle | ||||
import os | |||||
_pickler = pickle.Pickler | _pickler = pickle.Pickler | ||||
_unpickler = pickle.Unpickler | _unpickler = pickle.Unpickler | ||||
from typing import Any, List | from typing import Any, List | ||||
@@ -7,6 +8,7 @@ from typing import Any, List | |||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | ||||
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.envs.env import FASTNLP_NO_SYNC | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch import distributed as dist | from torch import distributed as dist | ||||
@@ -34,47 +36,15 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
) | ) | ||||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||||
def fastnlp_torch_gather_object(obj, dst=0, group=DEFAULT_TORCH_GROUP): | |||||
""" | """ | ||||
从其它 rank gather 东西到 dst rank 。 | 从其它 rank gather 东西到 dst rank 。 | ||||
Gathers picklable objects from the whole group in a single process. | |||||
Similar to :func:`gather`, but Python objects can be passed in. Note that the | |||||
object must be picklable in order to be gathered. | |||||
Args: | |||||
obj (Any): Input object. Must be picklable. | |||||
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it | |||||
should be correctly sized as the size of the group for this | |||||
collective and will contain the output. Must be ``None`` on non-dst | |||||
ranks. (default is ``None``) | |||||
dst (int, optional): Destination rank. (default is 0) | |||||
group: (ProcessGroup, optional): The process group to work on. If None, | |||||
the default process group will be used. Default is ``None``. | |||||
Returns: | |||||
None. On the ``dst`` rank, ``object_gather_list`` will contain the | |||||
output of the collective. | |||||
.. note:: Note that this API differs slightly from the gather collective | |||||
since it does not provide an async_op handle and thus will be a blocking | |||||
call. | |||||
.. note:: Note that this API is not supported when using the NCCL backend. | |||||
.. warning:: | |||||
:func:`gather_object` uses ``pickle`` module implicitly, which is | |||||
known to be insecure. It is possible to construct malicious pickle data | |||||
which will execute arbitrary code during unpickling. Only call this | |||||
function with data you trust. | |||||
Example:: | Example:: | ||||
>>> # Note: Process group initialization omitted on each rank. | |||||
>>> import torch.distributed as dist | |||||
>>> # Assumes world_size of 3. | >>> # Assumes world_size of 3. | ||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | ||||
>>> output = [None for _ in gather_objects] | >>> output = [None for _ in gather_objects] | ||||
>>> dist.gather_object( | |||||
>>> fastnlp_torch_gather_object( | |||||
gather_objects[dist.get_rank()], | gather_objects[dist.get_rank()], | ||||
output if dist.get_rank() == 0 else None, | output if dist.get_rank() == 0 else None, | ||||
dst=0 | dst=0 | ||||
@@ -82,7 +52,20 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAU | |||||
>>> # On rank 0 | >>> # On rank 0 | ||||
>>> output | >>> output | ||||
['foo', 12, {1: 2}] | ['foo', 12, {1: 2}] | ||||
:param obj: 需要发送的 obj 对象,需要是可以 pickable 的对象 | |||||
:param dst: 目标的 rank 。 | |||||
:param group: 在哪个 group 执行该函数。 | |||||
:return: 在 dst 上面返回 world_size 的 list,依次为 rank 0;rank 1...上 obj | |||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||||
return [obj] | |||||
if dist.get_rank() == dst: | |||||
object_gather_list = [None for _ in range(dist.get_world_size(group))] | |||||
else: | |||||
object_gather_list = None | |||||
if group is None: | if group is None: | ||||
group = DEFAULT_TORCH_GROUP | group = DEFAULT_TORCH_GROUP | ||||
@@ -212,6 +195,9 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) - | |||||
:param group: | :param group: | ||||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | ||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||||
return [obj] | |||||
if group is None: | if group is None: | ||||
group = DEFAULT_TORCH_GROUP | group = DEFAULT_TORCH_GROUP | ||||
if isinstance(obj, torch.Tensor): | if isinstance(obj, torch.Tensor): | ||||
@@ -232,12 +218,18 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR | |||||
""" | """ | ||||
将 src 上的 obj 对象广播到其它 rank 上。 | 将 src 上的 obj 对象广播到其它 rank 上。 | ||||
:param obj: | |||||
:param src: | |||||
:param obj: 需要发送的对象 | |||||
:param src: 从哪里发出。 | |||||
:param device: | :param device: | ||||
:param group: | |||||
:param group: 属于哪个通信 group | |||||
:return: | :return: | ||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||||
if src == dist.get_rank(group): | |||||
return obj | |||||
else: | |||||
return None | |||||
if group is None: | if group is None: | ||||
group = DEFAULT_TORCH_GROUP | group = DEFAULT_TORCH_GROUP | ||||
cur_rank = dist.get_rank(group) | cur_rank = dist.get_rank(group) | ||||
@@ -289,50 +281,23 @@ def all_gather_object(object_list, obj, group=None): | |||||
""" | """ | ||||
复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 | 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 | ||||
Gathers picklable objects from the whole group into a list. Similar to | |||||
:func:`all_gather`, but Python objects can be passed in. Note that the object | |||||
must be picklable in order to be gathered. | |||||
Args: | |||||
object_list (list[Any]): Output list. It should be correctly sized as the | |||||
size of the group for this collective and will contain the output. | |||||
object (Any): Pickable Python object to be broadcast from current process. | |||||
group (ProcessGroup, optional): The process group to work on. If None, | |||||
the default process group will be used. Default is ``None``. | |||||
Returns: | |||||
None. If the calling rank is part of this group, the output of the | |||||
collective will be populated into the input ``object_list``. If the | |||||
calling rank is not part of the group, the passed in ``object_list`` will | |||||
be unmodified. | |||||
.. note:: Note that this API differs slightly from the :func:`all_gather` | |||||
collective since it does not provide an ``async_op`` handle and thus | |||||
will be a blocking call. | |||||
.. note:: For NCCL-based processed groups, internal tensor representations | |||||
of objects must be moved to the GPU device before communication takes | |||||
place. In this case, the device used is given by | |||||
``torch.cuda.current_device()`` and it is the user's responsiblity to | |||||
ensure that this is set so that each rank has an individual GPU, via | |||||
``torch.cuda.set_device()``. | |||||
.. warning:: | |||||
:func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||||
known to be insecure. It is possible to construct malicious pickle data | |||||
which will execute arbitrary code during unpickling. Only call this | |||||
function with data you trust. | |||||
Example:: | Example:: | ||||
>>> # Note: Process group initialization omitted on each rank. | >>> # Note: Process group initialization omitted on each rank. | ||||
>>> import torch.distributed as dist | |||||
>>> # Assumes world_size of 3. | >>> # Assumes world_size of 3. | ||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | ||||
>>> output = [None for _ in gather_objects] | >>> output = [None for _ in gather_objects] | ||||
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
>>> all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
>>> output | >>> output | ||||
['foo', 12, {1: 2}] | ['foo', 12, {1: 2}] | ||||
:param object_list: | |||||
:param obj: | |||||
:param group: | |||||
:return: | |||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||||
return [obj] | |||||
if dist.distributed_c10d._rank_not_in_group(group): | if dist.distributed_c10d._rank_not_in_group(group): | ||||
return | return | ||||
if _TORCH_GREATER_EQUAL_1_8: | if _TORCH_GREATER_EQUAL_1_8: | ||||
@@ -35,7 +35,6 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||||
"'jittor', 'paddle', 'fleet'].") | "'jittor', 'paddle', 'fleet'].") | ||||
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | ||||
""" | """ | ||||
使用 command 通过 subprocess.Popen 开启新的进程。 | 使用 command 通过 subprocess.Popen 开启新的进程。 | ||||
@@ -60,30 +59,3 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy: | |||||
err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w') | err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w') | ||||
proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f) | proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f) | ||||
return proc | return proc | ||||
def load_model(filepath: Union[str, Path], backend: str = "torch", **kwargs): | |||||
r""" | |||||
对应 `load_model`,用来帮助用户加载之前通过 `load_model` 所保存的模型; | |||||
:param filepath: 加载的文件的位置; | |||||
:param backend: 使用哪种 backend 来加载该 filepath, 目前支持 ["torch", "paddle", "jittor"] 。 | |||||
""" | |||||
if filepath is None: | |||||
raise ValueError("Parameter `path` can not be None.") | |||||
assert backend is not None, "Parameter `backend` can not be None." | |||||
if backend == "torch": | |||||
import torch | |||||
_res = torch.load(filepath) | |||||
return _res | |||||
elif backend == "jittor": | |||||
raise NotImplementedError | |||||
elif backend == "paddle": | |||||
raise NotImplementedError | |||||
else: | |||||
raise ValueError("Parameter `backend` could only be one of these values: ['torch', 'jittor', 'paddle']") | |||||
@@ -13,15 +13,22 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||||
class Accuracy(Metric): | class Accuracy(Metric): | ||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | |||||
""" | |||||
计算 准确率 的 metric 。 | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | |||||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | ||||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | ||||
def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
r""" | r""" | ||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||||
get_metric 函数将根据 evaluate 函数累计的评价指标统计量来计算最终的评价结果. | |||||
:return dict evaluate_result: {"acc": float} | :return dict evaluate_result: {"acc": float} | ||||
""" | """ | ||||
@@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
import torch.nn.functional as F | |||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | ||||
@@ -31,7 +31,20 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
class ClassifyFPreRecMetric(Metric): | class ClassifyFPreRecMetric(Metric): | ||||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | ||||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | ||||
aggregate_when_get_metric: bool = False) -> None: | |||||
aggregate_when_get_metric: bool = None) -> None: | |||||
""" | |||||
:param tag_vocab: | |||||
:param ignore_labels: | |||||
:param num_class: | |||||
:param only_gross: | |||||
:param f_type: | |||||
:param beta: | |||||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | |||||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | super(ClassifyFPreRecMetric, self).__init__(backend=backend, | ||||
aggregate_when_get_metric=aggregate_when_get_metric) | aggregate_when_get_metric=aggregate_when_get_metric) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
@@ -35,6 +35,8 @@ class Element: | |||||
""" | """ | ||||
self._check_value_initialized() | self._check_value_initialized() | ||||
if self.aggregate_method is None: # 如果没有 aggregate 则不进行聚合。 | |||||
return | |||||
try: | try: | ||||
self._value = self.backend.aggregate(self._value, self.aggregate_method) | self._value = self.backend.aggregate(self._value, self.aggregate_method) | ||||
except AggregateMethodError as e: | except AggregateMethodError as e: | ||||
@@ -14,13 +14,13 @@ from fastNLP.core.metrics.element import Element | |||||
class Metric: | class Metric: | ||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | |||||
""" | """ | ||||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | :param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | ||||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | 函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | ||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | ||||
当 backend 不支持分布式时,该参数无意义。 | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | """ | ||||
self.backend = AutoBackend(backend) | self.backend = AutoBackend(backend) | ||||
self._updated = False | self._updated = False | ||||
@@ -43,7 +43,7 @@ class Metric: | |||||
:param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | ||||
:param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 | :param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 | ||||
:param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。 | |||||
:param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。如果设置为 None 则表示该 element 不进行聚合。 | |||||
:param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 | :param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 | ||||
Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | ||||
一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | 一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | ||||
@@ -218,7 +218,7 @@ class SpanFPreRecMetric(Metric): | |||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | ||||
only_gross: bool = True, f_type='micro', | only_gross: bool = True, f_type='micro', | ||||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | |||||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None: | |||||
r""" | r""" | ||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | ||||
@@ -234,7 +234,7 @@ class SpanFPreRecMetric(Metric): | |||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | ||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | ||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | ||||
当 backend 不支持分布式时,该参数无意义。 | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | """ | ||||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
@@ -222,14 +222,14 @@ class RandomSampler(ReproducibleSampler): | |||||
class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
def __init__(self, dataset, dist_mode:str='interval', **kwargs): | |||||
def __init__(self, dataset, **kwargs): | |||||
""" | """ | ||||
按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | 按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | ||||
:param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
:param kwargs: | :param kwargs: | ||||
""" | """ | ||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||||
super().__init__(dataset=dataset, **kwargs) | |||||
def __iter__(self): | def __iter__(self): | ||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | ||||
@@ -6,8 +6,9 @@ __all__ = [ | |||||
'is_cur_env_distributed', | 'is_cur_env_distributed', | ||||
'get_global_rank', | 'get_global_rank', | ||||
'rank_zero_call', | 'rank_zero_call', | ||||
'all_rank_call', | |||||
'get_gpu_count' | |||||
'all_rank_call_context', | |||||
'get_gpu_count', | |||||
'fastnlp_no_sync_context' | |||||
] | ] | ||||
@@ -7,10 +7,11 @@ __all__ = [ | |||||
'is_cur_env_distributed', | 'is_cur_env_distributed', | ||||
'get_global_rank', | 'get_global_rank', | ||||
'rank_zero_call', | 'rank_zero_call', | ||||
'all_rank_call' | |||||
'all_rank_call_context', | |||||
'fastnlp_no_sync_context' | |||||
] | ] | ||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC | |||||
def is_cur_env_distributed() -> bool: | def is_cur_env_distributed() -> bool: | ||||
@@ -41,24 +42,46 @@ def rank_zero_call(fn: Callable): | |||||
return a+b | return a+b | ||||
rank_zero_call(add)(1, 2) | rank_zero_call(add)(1, 2) | ||||
同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何 | |||||
意义。 | |||||
:param fn: 需要包裹的可执行的函数。 | :param fn: 需要包裹的可执行的函数。 | ||||
:return: | :return: | ||||
""" | """ | ||||
@wraps(fn) | @wraps(fn) | ||||
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: | def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: | ||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
return fn(*args, **kwargs) | |||||
with fastnlp_no_sync_context(level=2): | |||||
return fn(*args, **kwargs) | |||||
return None | return None | ||||
return wrapped_fn | return wrapped_fn | ||||
@contextmanager | @contextmanager | ||||
def all_rank_call(): | |||||
def fastnlp_no_sync_context(level=2): | |||||
""" | |||||
用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; | |||||
如果为 2 表示 barrier 与 gather/broadcast 都失效。 | |||||
:param int level: 可选 [0, 1, 2] | |||||
:return: | |||||
""" | |||||
old_level = os.environ.get(FASTNLP_NO_SYNC, None) | |||||
os.environ[FASTNLP_NO_SYNC] = f'{level}' | |||||
yield | |||||
if old_level is None: | |||||
os.environ.pop(FASTNLP_NO_SYNC) | |||||
else: | |||||
os.environ[FASTNLP_NO_SYNC] = old_level | |||||
@contextmanager | |||||
def all_rank_call_context(): | |||||
""" | """ | ||||
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 | 在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 | ||||
# 使用方式 | # 使用方式 | ||||
with all_rank_run(): | |||||
with all_rank_call_context(): | |||||
do_something # all rank will do | do_something # all rank will do | ||||
:param fn: | :param fn: | ||||
@@ -48,6 +48,10 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||||
# fastNLP 中初始化deque的默认大小 | # fastNLP 中初始化deque的默认大小 | ||||
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | ||||
# fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | |||||
# 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | |||||
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | |||||
# todo 注释 直接使用的变量 | # todo 注释 直接使用的变量 | ||||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | ||||
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | ||||
@@ -69,7 +69,7 @@ class TestFleetDriverFunction: | |||||
""" | """ | ||||
测试 get_no_sync_context 函数 | 测试 get_no_sync_context 函数 | ||||
""" | """ | ||||
res = self.driver.get_no_sync_context() | |||||
res = self.driver.get_model_no_sync_context() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -1,6 +1,6 @@ | |||||
import os | import os | ||||
from fastNLP.envs.distributed import rank_zero_call, all_rank_call | |||||
from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context | |||||
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context | from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context | ||||
@@ -70,7 +70,7 @@ class TestTorch: | |||||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | re_run_current_cmd_for_torch(1, output_from_new_proc='all') | ||||
# torch.distributed.init_process_group(backend='nccl') | # torch.distributed.init_process_group(backend='nccl') | ||||
# torch.distributed.barrier() | # torch.distributed.barrier() | ||||
with all_rank_call(): | |||||
with all_rank_call_context(): | |||||
with Capturing(no_del=True) as output: | with Capturing(no_del=True) as output: | ||||
write_something() | write_something() | ||||
output = output[0] | output = output[0] | ||||
@@ -80,7 +80,7 @@ class TestTorch: | |||||
else: | else: | ||||
assert '11111' in output | assert '11111' in output | ||||
with all_rank_call(): | |||||
with all_rank_call_context(): | |||||
with Capturing(no_del=True) as output: | with Capturing(no_del=True) as output: | ||||
rank_zero_call(write_other_thing)() | rank_zero_call(write_other_thing)() | ||||