
Fix the bug with Evaluator's evaluate_use_dist_sampler when used from Trainer

tags/v1.0.0alpha
yh committed 2 years ago
commit d4bccf3f6a
3 changed files with 221 additions and 4 deletions
  1. +4 -2     fastNLP/core/controllers/evaluator.py
  2. +3 -2     fastNLP/core/controllers/trainer.py
  3. +214 -0   tests/core/controllers/test_evaluator_torch.py
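
In short: before this commit, Trainer always forwarded an explicit None to the Evaluator when evaluate_use_dist_sampler was not given, and the Evaluator treated that None as "do not use a distributed sampler". After this commit, None means "fall back": first to the trainer's use_dist_sampler setting, then to whether the driver is distributed. A minimal sketch of that fallback order follows; the standalone helper function is illustrative only and is not part of the fastNLP source.

def resolve_evaluate_use_dist_sampler(evaluate_use_dist_sampler,
                                      use_dist_sampler,
                                      driver_is_distributed):
    # An explicitly passed value always wins.
    if evaluate_use_dist_sampler is not None:
        return evaluate_use_dist_sampler
    # Otherwise follow the Trainer's use_dist_sampler setting.
    if use_dist_sampler is not None:
        return use_dist_sampler
    # Finally fall back to whether the driver is distributed.
    return driver_is_distributed


# Training without a distributed sampler now keeps evaluation consistent:
assert resolve_evaluate_use_dist_sampler(None, False, True) is False
# With nothing specified, a distributed driver still gets a distributed sampler:
assert resolve_evaluate_use_dist_sampler(None, None, True) is True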

+4 -2   fastNLP/core/controllers/evaluator.py

@@ -146,7 +146,9 @@ class Evaluator:

        self.separator = kwargs.get('separator', '#')
        self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
-       use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
+       use_dist_sampler = kwargs.get("use_dist_sampler", None)
+       if use_dist_sampler is None:
+           use_dist_sampler = self.driver.is_distributed()
        if use_dist_sampler:
            self._dist_sampler = "unrepeatdist"
        else:
@@ -384,7 +386,7 @@ class _MetricsWrapper:

            # If the data is distributed but the metric does not aggregate, the result may be wrong.
            if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
                logger.rank_zero_warning(
-                   "You have replace the sampler as distributed sampler when evaluation, but your metric "
+                   "You have replaced the sampler with a distributed sampler for evaluation, but your metric "
                    f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True)
            if metric.aggregate_when_get_metric is None:
                metric.aggregate_when_get_metric = evaluator._dist_sampler is not None
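
With the change above, passing use_dist_sampler=None (or omitting it) lets the Evaluator decide from the driver. A small usage sketch, modeled on the new test added below; the model, dataloader and metric objects are placeholders, not part of this diff.

# Placeholders: `model`, `val_dataloader` and `metric` are assumed to exist.
evaluator = Evaluator(model, dataloaders=val_dataloader, metrics={'c': metric},
                      driver='torch', device=[0, 1],
                      use_dist_sampler=None,   # None -> falls back to driver.is_distributed()
                      progress_bar='tqdm')
results = evaluator.run()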


+3 -2   fastNLP/core/controllers/trainer.py

@@ -288,7 +288,8 @@ class Trainer(TrainerEventTrigger):

        * *use_dist_sampler* -- whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically
          decides which samples each card reads, so that within one epoch the samples across all cards add up to exactly one pass over
          the whole dataset. By default this is set according to whether the driver is distributed.
-       * *evaluate_use_dist_sampler* -- whether, when running distributed, the ``Evaluator`` should replace the dataloader's ``sampler``
-         with a distributed ``sampler``; defaults to ``True``;
+       * *evaluate_use_dist_sampler* -- whether, when running distributed, the ``Evaluator`` should replace the dataloader's ``sampler``
+         with a distributed ``sampler``; if this value is not passed, it stays consistent with the ``use_dist_sampler`` parameter;
        * *output_from_new_proc* -- should be a string describing how the output streams of the other processes in a multi-process driver
          are handled; it should be one of ["all", "ignore", "only_error"]; when the value is none of these, it is treated as a folder name,
          and the output streams of the other ranks are redirected to log files saved in the folder given by this value; defaults to "only_error";
@@ -477,7 +478,7 @@ class Trainer(TrainerEventTrigger):
            driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
            evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
            output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
-           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
+           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
            progress_bar=progress_bar)

        if train_fn is not None and not isinstance(train_fn, str):
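
Correspondingly, when evaluate_use_dist_sampler is not passed to the Trainer, it now follows use_dist_sampler instead of silently becoming None. A hedged usage sketch; everything other than the keyword arguments (model, optimizer, dataloaders, metric) is a placeholder, not from this diff.

# Placeholders: model, optimizer, train_dl, val_dl, metric are assumed to exist.
trainer = Trainer(
    model=model,
    driver='torch',
    device=[0, 1],
    optimizers=optimizer,
    train_dataloader=train_dl,
    evaluate_dataloaders=val_dl,
    metrics={'acc': metric},
    use_dist_sampler=False,   # keep a non-distributed sampler for training
    # evaluate_use_dist_sampler not given -> evaluation now keeps a non-distributed sampler too
)
trainer.run()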


+214 -0   tests/core/controllers/test_evaluator_torch.py

@@ -0,0 +1,214 @@
import pytest

from fastNLP import Metric, Evaluator

from dataclasses import dataclass
from typing import Any
from itertools import product

from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP import Event

# Check that results are aggregated correctly across processes.


class DistMetric(Metric):
    def __init__(self, aggregate_when_get_metric=None):
        super().__init__(aggregate_when_get_metric=aggregate_when_get_metric)
        self.register_element('count', value=0, aggregate_method='sum')
        self.data = 0

    def update(self, y):
        self.count += len(y)
        self.data += len(y)

    def get_metric(self) -> dict:
        count2 = sum(self.all_gather_object(self.data))
        return {'count': self.count.item(), 'count2': count2}

    def reset(self):
        self.data = 0



if _NEED_IMPORT_TORCH:
    from torch.optim import SGD
    from torch.utils.data import DataLoader
    import torch.distributed as dist
    from torch.utils.data import Dataset
    import torch


    class DataSet(Dataset):
        def __init__(self, num_samples=1000, num_features=10):
            g = torch.Generator()
            g.manual_seed(1000)
            self.data = torch.randn(num_samples, num_features, generator=g)
            self.y = self.data.argmax(dim=-1)

        def __getitem__(self, item):
            return {'x': self.data[item], 'y': self.data[item]}

        def __len__(self):
            return len(self.data)


@dataclass
class NormalClassificationTrainTorchConfig:
    num_labels: int = 10
    feature_dimension: int = 10
    seed: int = 0

    batch_size: int = 4
    shuffle: bool = True

@dataclass
class TrainerParameters:
    model: Any = None
    optimizers: Any = None
    train_dataloader: Any = None
    evaluate_dataloaders: Any = None
    input_mapping: Any = None
    output_mapping: Any = None
    metrics: Any = None


@pytest.fixture(scope="module", params=[1], autouse=True)
def trainer_params(request):
    trainer_params = TrainerParameters()

    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=NormalClassificationTrainTorchConfig.num_labels,
        feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)

    dataset = DataSet(99, num_features=NormalClassificationTrainTorchConfig.feature_dimension)
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=NormalClassificationTrainTorchConfig.batch_size,
        shuffle=True
    )
    trainer_params.train_dataloader = _dataloader
    trainer_params.evaluate_dataloaders = _dataloader

    return trainer_params


@pytest.mark.torch
@pytest.mark.parametrize('device', [[0, 1], None])
@magic_argv_env_context
def test_1(trainer_params: TrainerParameters, device):
    # Test whether results are aggregated automatically.
    for aggregate_when_get_metric, use_dist_sampler in product([True, False], [True, False, None]):
        metric = DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)

        evaluator = Evaluator(trainer_params.model, dataloaders=trainer_params.evaluate_dataloaders,
                              metrics={'c': metric},
                              driver='torch', device=device, use_dist_sampler=use_dist_sampler,
                              progress_bar='tqdm')
        if use_dist_sampler is None:
            use_dist_sampler = device is not None
        results = evaluator.run()
        num_samples = len(trainer_params.evaluate_dataloaders.dataset)
        if device is None:
            assert results['count#c'] == num_samples
            assert results['count2#c'] == num_samples
        else:
            if aggregate_when_get_metric is True and use_dist_sampler is True:
                assert results['count#c'] == num_samples
                assert results['count2#c'] == num_samples
            elif aggregate_when_get_metric is True and use_dist_sampler is False:
                assert results['count#c'] == 2 * num_samples
                assert results['count2#c'] == 2 * num_samples
            elif aggregate_when_get_metric is False and use_dist_sampler is True:
                assert results['count#c'] in (49, 50)  # counts differ between cards
                assert results['count2#c'] in (49, 50)
            else:
                assert results['count#c'] == num_samples
                assert results['count2#c'] == num_samples

    if dist.is_initialized():
        dist.destroy_process_group()



@pytest.mark.torch
@pytest.mark.parametrize('device', [[0, 1], None])
@magic_argv_env_context
def test_2(trainer_params: TrainerParameters, device):
    # Test whether results are aggregated automatically.
    for aggregate_when_get_metric, use_dist_sampler in product([True, False], [True, False, None]):
        metric = DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)

        num_samples = len(trainer_params.evaluate_dataloaders.dataset)

        @Trainer.on(Event.on_sanity_check_end())
        def on_valid_end(trainer, results):
            if device is None:
                assert results['count#c'] == num_samples
                assert results['count2#c'] == num_samples
            else:
                if aggregate_when_get_metric is True and use_dist_sampler is True:
                    assert results['count#c'] == num_samples
                    assert results['count2#c'] == num_samples
                elif aggregate_when_get_metric is True and use_dist_sampler is False:
                    assert results['count#c'] == 2 * num_samples
                    assert results['count2#c'] == 2 * num_samples
                elif aggregate_when_get_metric is False and use_dist_sampler is True:
                    assert results['count#c'] in (49, 50)  # counts differ between cards
                    assert results['count2#c'] in (49, 50)
                else:
                    assert results['count#c'] == num_samples
                    assert results['count2#c'] == num_samples

        trainer = Trainer(
            model=trainer_params.model,
            driver='torch',
            device=device,
            optimizers=trainer_params.optimizers,
            train_dataloader=trainer_params.train_dataloader,
            evaluate_dataloaders=trainer_params.evaluate_dataloaders,
            metrics={'c': metric},
            evaluate_every=-1,
            n_epochs=0,
            output_from_new_proc="all",
            use_dist_sampler=use_dist_sampler,
            progress_bar='tqdm'
        )

        if use_dist_sampler is None:
            use_dist_sampler = device is not None

        trainer.run(num_eval_sanity_batch=-1)

        trainer = Trainer(
            model=trainer_params.model,
            driver='torch',
            device=device,
            optimizers=trainer_params.optimizers,
            train_dataloader=trainer_params.train_dataloader,
            evaluate_dataloaders=trainer_params.evaluate_dataloaders,
            metrics={'c': DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)},
            evaluate_every=-1,
            n_epochs=0,
            output_from_new_proc="all",
            use_dist_sampler=not (use_dist_sampler is True),  # deliberately use the opposite value
            evaluate_use_dist_sampler=use_dist_sampler,
            progress_bar='rich'  # also checks that the progress bar can be swapped
        )
        trainer.run(num_eval_sanity_batch=-1)

    if dist.is_initialized():
        dist.destroy_process_group()






