
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
commit 66eeee653b
28 changed files with 398 additions and 213 deletions
  1. +1 -0 fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  2. +10 -1 fastNLP/core/dataloaders/torch_dataloader/fdl.py
  3. +2 -2 fastNLP/core/utils/dummy_class.py
  4. +4 -4 tests/core/callbacks/test_callback_event.py
  5. +11 -7 tests/core/callbacks/test_checkpoint_callback_torch.py
  6. +6 -4 tests/core/callbacks/test_more_evaluate_callback.py
  7. +1 -0 tests/core/collators/padders/test_get_padder.py
  8. +18 -18 tests/core/collators/padders/test_paddle_padder.py
  9. +1 -2 tests/core/collators/padders/test_raw_padder.py
  10. +1 -1 tests/core/collators/test_collator.py
  11. +105 -8 tests/core/controllers/test_trainer_event_trigger.py
  12. +5 -5 tests/core/controllers/test_trainer_other_things.py
  13. +6 -4 tests/core/controllers/test_trainer_w_evaluator_torch.py
  14. +7 -3 tests/core/controllers/test_trainer_wo_evaluator_torch.py
  15. +3 -5 tests/core/controllers/utils/test_utils.py
  16. +2 -1 tests/core/drivers/jittor_driver/test_single_device.py
  17. +149 -118 tests/core/drivers/torch_driver/test_ddp.py
  18. +8 -8 tests/core/drivers/torch_driver/test_initialize_torch_driver.py
  19. +9 -4 tests/core/metrics/test_accuracy_torch.py
  20. +8 -3 tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
  21. +9 -4 tests/core/metrics/test_span_f1_rec_acc_torch.py
  22. +4 -2 tests/core/metrics/utils.py
  23. +1 -0 tests/core/utils/test_cache_results.py
  24. +2 -1 tests/envs/test_set_backend.py
  25. +3 -1 tests/helpers/callbacks/helper_callbacks_torch.py
  26. +6 -2 tests/helpers/datasets/torch_data.py
  27. +10 -5 tests/helpers/models/torch_model.py
  28. +6 -0 tests/pytest.ini

+1 -0 fastNLP/core/dataloaders/jittor_dataloader/fdl.py

@@ -91,6 +91,7 @@ class JittorDataLoader:
self.dataset.dataset.set_attrs(batch_size=1)
# if the user provides a collate_fn, it automatically replaces the collate_batch function supplied by jittor
# self._collate_fn = _collate_fn
self.cur_batch_indices = None

def __iter__(self):
# TODO collate_fn cannot be set after the first iteration; setting it then has no effect


+10 -1 fastNLP/core/dataloaders/torch_dataloader/fdl.py

@@ -3,7 +3,7 @@ __all__ = [
'prepare_torch_dataloader'
]

from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):

if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

def get_batch_indices(self) -> List[int]:
"""
Get the indices of the current batch

:return:
"""
return self.cur_batch_indices


def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,

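Usage note (not part of the commit): the new get_batch_indices() accessor is meant to be read while iterating. A minimal sketch, assuming prepare_torch_dataloader accepts a DataSet as shown above and that cur_batch_indices is populated during iteration; the data values and variable names are made up:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

    # toy dataset with variable-length "x" fields (hypothetical data)
    ds = DataSet({"x": [[1, 2], [3, 4, 5], [6]], "y": [0, 1, 0]})
    dl = prepare_torch_dataloader(ds, batch_size=2, shuffle=False)
    for batch in dl:
        # indices of the dataset rows that produced the current batch
        print(dl.get_batch_indices())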

+2 -2 fastNLP/core/utils/dummy_class.py

@@ -1,5 +1,5 @@
import functools

class DummyClass:
def __call__(self, *args, **kwargs):
return
def __init__(self, *args, **kwargs):
pass
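Side note: replacing __call__ with __init__ is what lets DummyClass stand in as a base class in the guarded imports used throughout the test changes below. A minimal sketch of that pattern (the subclass is hypothetical, not code from the commit):

    from fastNLP.envs.imports import _NEED_IMPORT_TORCH

    if _NEED_IMPORT_TORCH:
        from torch.nn import Module
    else:
        # DummyClass now swallows arbitrary constructor arguments, so it can
        # replace torch.nn.Module when torch is not installed
        from fastNLP.core.utils.dummy_class import DummyClass as Module

    class TinyModel(Module):
        def __init__(self):
            super().__init__()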

+4 -4 tests/core/callbacks/test_callback_event.py

@@ -162,7 +162,7 @@ class TestCallbackEvents:
def test_every(self):

# which event is used here does not matter, because these tests are decoupled from the Trainer;
event_state = Events.on_train_begin() # passing nothing should default to every=1;
event_state = Event.on_train_begin() # passing nothing should default to every=1;
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data
@@ -174,7 +174,7 @@ class TestCallbackEvents:
_res.append(cu_res)
assert _res == list(range(100))

event_state = Events.on_train_begin(every=10)
event_state = Event.on_train_begin(every=10)
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data
@@ -187,7 +187,7 @@ class TestCallbackEvents:
assert _res == [w - 1 for w in range(10, 101, 10)]

def test_once(self):
event_state = Events.on_train_begin(once=10)
event_state = Event.on_train_begin(once=10)

@Filter(once=event_state.once)
def _fn(data):
@@ -220,7 +220,7 @@ def test_callback_events_torch():
return True
return False

event_state = Events.on_train_begin(filter_fn=filter_fn)
event_state = Event.on_train_begin(filter_fn=filter_fn)

@Filter(filter_fn=event_state.filter_fn)
def _fn(trainer, data):


+11 -7 tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -2,9 +2,6 @@ import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path
import re
import time
@@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy
from fastNLP.core.log import logger

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

@dataclass
class ArgMaxDatasetConfig:
@@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2(

if version == 0:
callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc",
folder=path,
every_n_epochs=None,
@@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2(
topk=None,
last=False,
on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
)
]
elif version == 1:
callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc",
folder=path,
every_n_epochs=None,
@@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2(
topk=1,
last=True,
on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
)
]



+6 -4 tests/core/callbacks/test_more_evaluate_callback.py

@@ -12,9 +12,7 @@ import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

from pathlib import Path
import re

@@ -29,7 +27,11 @@ from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

@dataclass
class ArgMaxDatasetConfig:


+1 -0 tests/core/collators/padders/test_get_padder.py

@@ -17,6 +17,7 @@ def test_get_element_shape_dtype():
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
@pytest.mark.torch
@pytest.mark.paddle
@pytest.mark.jittor
def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch")


+18 -18 tests/core/collators/padders/test_paddle_padder.py

@@ -1,7 +1,7 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder
from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

@@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE:


@pytest.mark.paddle
class TestpaddleNumberPadder:
class TestPaddleNumberPadder:
def test_run(self):
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, paddle.Tensor)
@@ -20,9 +20,9 @@ class TestpaddleNumberPadder:


@pytest.mark.paddle
class TestpaddleSequencePadder:
class TestPaddleSequencePadder:
def test_run(self):
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
@@ -32,20 +32,20 @@ class TestpaddleSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]])
# assert (a>67).sum()==0 # because the range of int8 is -67 - 66
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)


@pytest.mark.paddle
class TestpaddleTensorPadder:
class TestPaddleTensorPadder:
def test_run(self):
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
a = [paddle.zeros((3,)), paddle.zeros((2,))]
a = padder(a)
shape = a.shape
@@ -74,7 +74,7 @@ class TestpaddleTensorPadder:
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))]
a = padder(a)
shape = a.shape
@@ -85,7 +85,7 @@ class TestpaddleTensorPadder:
])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)]
a = padder(a)
shape = a.shape
@@ -96,11 +96,11 @@ class TestpaddleTensorPadder:
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)

def test_v1(self):
print(paddle.zeros((3, )).dtype)

+1 -2 tests/core/collators/padders/test_raw_padder.py

@@ -23,7 +23,6 @@ class TestRawSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int)

+1 -1 tests/core/collators/test_collator.py

@@ -4,7 +4,7 @@ import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.new_collator import Collator
from fastNLP.core.collators.collator import Collator


def _assert_equal(d1, d2):


+105 -8 tests/core/controllers/test_trainer_event_trigger.py

@@ -1,10 +1,7 @@
import pytest
from typing import Any
from dataclasses import dataclass
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist


from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks.callback_event import Event
@@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist


@dataclass
@@ -96,10 +99,10 @@ def test_trainer_event_trigger_1(
if dist.is_initialized():
dist.destroy_process_group()

Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]

@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@@ -211,7 +214,101 @@ def test_trainer_event_trigger_2(
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()

Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
import re

once_message_1 = "This message should be typed 1 times."
once_message_2 = "test_filter_fn"
once_message_3 = "once message 3"
twice_message = "twice message hei hei"

@Trainer.on(Event.on_train_epoch_begin(every=2))
def train_epoch_begin_1(trainer):
print(once_message_1)

@Trainer.on(Event.on_train_epoch_begin())
def train_epoch_begin_2(trainer):
print(twice_message)

@Trainer.on(Event.on_train_epoch_begin(once=2))
def train_epoch_begin_3(trainer):
print(once_message_3)

def filter_fn(filter, trainer):
if trainer.cur_epoch_idx == 1:
return True
else:
return False

@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn))
def test_filter_fn(trainer):
print(once_message_2)

with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()


once_pattern_1 = re.compile(once_message_1)
once_pattern_2 = re.compile(once_message_2)
once_pattern_3 = re.compile(once_message_3)
twice_pattern = re.compile(twice_message)

once_res_1 = once_pattern_1.findall(output[0])
assert len(once_res_1) == 1
once_res_2 = once_pattern_2.findall(output[0])
assert len(once_res_2) == 1
once_res_3 = once_pattern_3.findall(output[0])
assert len(once_res_3) == 1
twice_res = twice_pattern.findall(output[0])
assert len(twice_res) == 2
+5 -5 tests/core/controllers/test_trainer_other_things.py

@@ -1,22 +1,22 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from fastNLP.core.callbacks import Event
from tests.helpers.utils import magic_argv_env_context


@magic_argv_env_context
def test_trainer_torch_without_evaluator():
@Trainer.on(Events.on_train_epoch_begin(every=10))
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things")
def fn1(trainer):
pass

@Trainer.on(Events.on_train_batch_begin(every=10))
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn2(trainer, batch, indices):
pass

with pytest.raises(AssertionError):
@Trainer.on(Events.on_train_batch_begin(every=10))
with pytest.raises(BaseException):
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn3(trainer, batch):
pass



+6 -4 tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -2,9 +2,7 @@
Note: the test functions in this file should all be built from test functions already covered in `test_trainer_w_evaluator_torch.py`, modified by adding metrics and an evaluator;
"""
import pytest
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist

from dataclasses import dataclass
from typing import Any
from torchmetrics import Accuracy
@@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist

@dataclass
class NormalClassificationTrainTorchConfig:


+7 -3 tests/core/controllers/test_trainer_wo_evaluator_torch.py

@@ -2,9 +2,7 @@ import os.path
import subprocess
import sys
import pytest
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader

from dataclasses import dataclass
from typing import Any
from pathlib import Path
@@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader


@dataclass
@@ -286,6 +289,7 @@ def test_trainer_on_exception(
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("version", [0, 1, 2, 3])
@magic_argv_env_context
def test_torch_distributed_launch_1(version):


+3 -5 tests/core/controllers/utils/test_utils.py

@@ -11,7 +11,7 @@ class Test_WrapDataLoader:
for sanity_batches in all_sanity_batches:
data = NormalSampler(num_of_data=1000)
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
dataloader = iter(wrapper(dataloader=data))
dataloader = iter(wrapper)
mark = 0
while True:
try:
@@ -32,8 +32,7 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
dataloader = iter(dataloader)
dataloader = iter(wrapper)
all_supposed_running_data_num = 0
while True:
try:
@@ -55,6 +54,5 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
length.append(len(dataloader))
length.append(len(wrapper))
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))])

+2 -1 tests/core/drivers/jittor_driver/test_single_device.py

@@ -15,7 +15,7 @@ else:



class Model (Module):
class Model(Module):
def __init__ (self):
super (Model, self).__init__()
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
@@ -45,6 +45,7 @@ class Model (Module):
return x

@pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice:

def test_on_gpu_without_fp16(self):


+149 -118 tests/core/drivers/torch_driver/test_ddp.py

@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler

def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device]
@@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
#
############################################################################

@pytest.mark.torch
@magic_argv_env_context
def test_multi_drivers():
"""
Test the case where multiple TorchDDPDriver instances are used.
"""
generate_driver(10, 10)
generate_driver(20, 10)
with pytest.raises(RuntimeError):
# different device settings should raise an error
generate_driver(20, 3, device=[0,1,2])
assert False
dist.barrier()

if dist.is_initialized():
dist.destroy_process_group()

@pytest.mark.torch
class TestDDPDriverFunction:
"""
Test class for some simple TorchDDPDriver functions; mostly checks that they run and that there are no import errors
"""

@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)

@magic_argv_env_context
def test_multi_drivers(self):
def test_simple_functions(self):
"""
Test the case where multiple TorchDDPDriver instances are used.
A quick test of several functions
"""
driver2 = generate_driver(20, 10)
with pytest.raises(RuntimeError):
# different device settings should raise an error
driver3 = generate_driver(20, 3, device=[0,1,2])
assert False
dist.barrier()
driver = generate_driver(10, 10)

@magic_argv_env_context
def test_move_data_to_device(self):
"""
This function only calls torch_move_data_to_device; its test cases are in tests/core/utils/test_torch_utils.py,
so they are not repeated here
Test the move_data_to_device function. This function only calls torch_move_data_to_device; its test cases are in
tests/core/utils/test_torch_utils.py, so they are not repeated here
"""
self.driver.move_data_to_device(torch.rand((32, 64)))

driver.move_data_to_device(torch.rand((32, 64)))
dist.barrier()

@magic_argv_env_context
def test_is_distributed(self):
"""
Test the is_distributed function
"""
assert self.driver.is_distributed() == True
assert driver.is_distributed() == True
dist.barrier()

@magic_argv_env_context
def test_get_no_sync_context(self):
"""
Test the get_no_sync_context function
"""
res = self.driver.get_model_no_sync_context()
res = driver.get_model_no_sync_context()
dist.barrier()

@magic_argv_env_context
def test_is_global_zero(self):
"""
Test the is_global_zero function
"""
self.driver.is_global_zero()
driver.is_global_zero()
dist.barrier()

@magic_argv_env_context
def test_unwrap_model(self):
"""
Test the unwrap_model function
"""
self.driver.unwrap_model()
driver.unwrap_model()
dist.barrier()

@magic_argv_env_context
def test_get_local_rank(self):
"""
Test the get_local_rank function
"""
self.driver.get_local_rank()
driver.get_local_rank()
dist.barrier()

@magic_argv_env_context
def test_all_gather(self):
"""
Test the all_gather function
The detailed tests are done in test_dist_utils.py
"""
obj = {
"rank": self.driver.global_rank
"rank": driver.global_rank
}
obj_list = self.driver.all_gather(obj, group=None)
obj_list = driver.all_gather(obj, group=None)
for i, res in enumerate(obj_list):
assert res["rank"] == i

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_broadcast_object(self, src_rank):
"""
Test the broadcast_object function
The detailed tests are done in test_dist_utils.py
"""
if self.driver.global_rank == src_rank:
if driver.global_rank == 0:
obj = {
"rank": self.driver.global_rank
"rank": driver.global_rank
}
else:
obj = None
res = self.driver.broadcast_object(obj, src=src_rank)
assert res["rank"] == src_rank
res = driver.broadcast_object(obj, src=0)
assert res["rank"] == 0

if dist.is_initialized():
dist.destroy_process_group()

############################################################################
#
@@ -187,7 +180,6 @@ class TestSetDistReproDataloader:
@classmethod
def setup_class(cls):
cls.device = [0, 1]
cls.driver = generate_driver(10, 10, device=cls.device)

def setup_method(self):
self.dataset = TorchNormalDataset(40)
@@ -204,17 +196,20 @@ class TestSetDistReproDataloader:
Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler
In this case the batch_sampler should be replaced by the BucketedBatchSampler passed as dist
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -223,9 +218,10 @@ class TestSetDistReproDataloader:
Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler
In this case batch_sampler.sampler should be replaced by the RandomSampler passed as dist
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -234,9 +230,11 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)

dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
"""
The case where the `dist` argument is None; this happens during trainer and evaluator initialization when the user specifies `use_dist_sampler`
@@ -251,15 +249,17 @@ class TestSetDistReproDataloader:
Test the behavior of set_dist_repro_dataloader when dist is None and reproducible is True
When the user has initialized the distributed environment outside the driver, fastNLP does not support resuming from a checkpoint, so an error should be raised
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError):
# a RuntimeError should be raised
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True)

dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
# @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
@@ -268,21 +268,24 @@ class TestSetDistReproDataloader:
In this case the passed-in dataloader's batch_sampler should already have run set_distributed, producing a new dataloader whose batch_sampler
is the same as that of the original dataloader
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
num_replicas=driver.world_size,
rank=driver.global_rank,
pad=True
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)

dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -292,12 +295,13 @@ class TestSetDistReproDataloader:
In this case the passed-in dataloader's batch_sampler.sampler should already have run set_distributed, producing a new dataloader whose
batch_sampler.sampler is the same as that of the original dataloader
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
num_replicas=driver.world_size,
rank=driver.global_rank
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -307,9 +311,11 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -318,11 +324,14 @@ class TestSetDistReproDataloader:
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one
In this case the original dataloader is returned directly, without any processing.
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)

assert replaced_loader is dataloader
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

"""
The case where the `dist` argument is 'dist'; this happens during trainer initialization when the user specifies the `use_dist_sampler` argument
@@ -337,12 +346,13 @@ class TestSetDistReproDataloader:
behavior
In this case a new dataloader should be returned whose batch_sampler is the same as the original dataloader's, with the distributed-related attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
)
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
@@ -351,6 +361,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -361,8 +373,9 @@ class TestSetDistReproDataloader:
In this case a new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, with the distributed-related
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
@@ -372,6 +385,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -381,8 +396,9 @@ class TestSetDistReproDataloader:
In this case a new dataloader should be returned, with its batch_sampler.sampler replaced by a RandomSampler and the distributed-related
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -392,6 +408,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

"""
The case where the `dist` argument is 'unrepeatdist'; this happens during evaluator initialization when the user specifies the `use_dist_sampler` argument
@@ -407,8 +425,9 @@ class TestSetDistReproDataloader:
In this case a new dataloader should be returned, with the original Sampler replaced by an UnrepeatedRandomSampler and the distributed-related
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -418,6 +437,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -427,8 +448,9 @@ class TestSetDistReproDataloader:
behavior
In this case a new dataloader should be returned, with the original Sampler re-instantiated
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -439,6 +461,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@@ -448,8 +472,9 @@ class TestSetDistReproDataloader:
In this case a new dataloader should be returned, with the sampler replaced by an UnrepeatedSequentialSampler and the distributed-related
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -459,6 +484,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()

def check_distributed_sampler(self, sampler):
"""
@@ -469,7 +496,7 @@ class TestSetDistReproDataloader:
if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle):
"""
Check that the result of set_dist_repro_dataloader is correct in the multi-GPU case
"""
@@ -501,8 +528,8 @@ class TestSetDistReproDataloader:
drop_last=False,
)
new_loader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
num_replicas=driver.world_size,
rank=driver.global_rank,
pad=True
)
new_loader.batch_sampler.load_state_dict(sampler_states)
@@ -512,8 +539,8 @@ class TestSetDistReproDataloader:
# rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
new_loader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
num_replicas=driver.world_size,
rank=driver.global_rank
)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
@@ -534,11 +561,6 @@ class TestSaveLoad:
Test the behavior of the save/load related functions in the multi-GPU case
"""

@classmethod
def setup_class(cls):
# an error is raised if setup is not done here
cls.driver = generate_driver(10, 10)

def setup_method(self):
self.dataset = TorchArgMaxDataset(10, 20)

@@ -552,26 +574,26 @@ class TestSaveLoad:
path = "model"

dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)

self.driver1.save_model(path, only_state_dict)
driver1.save_model(path, only_state_dict)

# synchronize
dist.barrier()
self.driver2.load_model(path, only_state_dict)
driver2.load_model(path, only_state_dict)

for idx, batch in enumerate(dataloader):
batch = self.driver1.move_data_to_device(batch)
res1 = self.driver1.model(
batch = driver1.move_data_to_device(batch)
res1 = driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
res2 = driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
@@ -580,6 +602,9 @@ class TestSaveLoad:
finally:
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@@ -593,7 +618,7 @@ class TestSaveLoad:
path = "model.ckp"
num_replicas = len(device)

self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
dataloader = dataloader_with_bucketedbatchsampler(
self.dataset,
@@ -603,8 +628,8 @@ class TestSaveLoad:
drop_last=False
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
num_replicas=driver1.world_size,
rank=driver1.global_rank,
pad=True
)
num_consumed_batches = 2
@@ -623,7 +648,7 @@ class TestSaveLoad:
# save the state
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
# load
# change the batch_size
dataloader = dataloader_with_bucketedbatchsampler(
@@ -634,11 +659,11 @@ class TestSaveLoad:
drop_last=False
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
num_replicas=driver2.world_size,
rank=driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. check the optimizer state
# TODO the optimizer state_dict is always empty
@@ -652,7 +677,7 @@ class TestSaveLoad:

# 3. check whether fp16 was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

# 4. check that the model parameters are correct
# 5. check the batch_idx
@@ -664,16 +689,16 @@ class TestSaveLoad:

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
res1 = driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
res2 = driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
@@ -686,6 +711,9 @@ class TestSaveLoad:
finally:
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@@ -700,13 +728,13 @@ class TestSaveLoad:

num_replicas = len(device)

self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
driver1 = generate_driver(10, 10, device=device, fp16=fp16)
driver2 = generate_driver(10, 10, device=device, fp16=False)

dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
num_replicas=driver1.world_size,
rank=driver1.global_rank,
pad=True
)
num_consumed_batches = 2
@@ -726,18 +754,18 @@ class TestSaveLoad:
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
# load
# change the batch_size
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
num_replicas=driver2.world_size,
rank=driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. check the optimizer state
@@ -753,7 +781,7 @@ class TestSaveLoad:
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. check whether fp16 was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

# 4. check that the model parameters are correct
# 5. check the batch_idx
@@ -765,16 +793,16 @@ class TestSaveLoad:

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
res1 = driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
res2 = driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
@@ -786,4 +814,7 @@ class TestSaveLoad:
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

finally:
rank_zero_rm(path)
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()

+8 -8 tests/core/drivers/torch_driver/test_initialize_torch_driver.py

@@ -2,12 +2,14 @@ import pytest

from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
from fastNLP.envs import get_gpu_count
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.utils import magic_argv_env_context

import torch

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import device as torchdevice
else:
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice

@pytest.mark.torch
def test_incorrect_driver():
@@ -20,7 +22,7 @@ def test_incorrect_driver():
@pytest.mark.torch
@pytest.mark.parametrize(
"device",
["cpu", "cuda:0", 0, torch.device("cuda:0")]
["cpu", "cuda:0", 0, torchdevice("cuda:0")]
)
@pytest.mark.parametrize(
"driver",
@@ -83,7 +85,6 @@ def test_get_ddp(driver, device):
("driver", "device"),
[("torch_ddp", "cpu")]
)
@magic_argv_env_context
def test_get_ddp_cpu(driver, device):
"""
Test attempting to initialize distributed training on cpu
@@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device):
@pytest.mark.torch
@pytest.mark.parametrize(
"device",
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1]
[-2, [0, 20, 3], [-2], 20]
)
@pytest.mark.parametrize(
"driver",
["torch", "torch_ddp"]
)
@magic_argv_env_context
def test_device_out_of_range(driver, device):
"""
Test the case where the given device is out of range


+9 -4 tests/core/metrics/test_accuracy_torch.py

@@ -7,15 +7,20 @@ import copy
import socket
import pytest
import numpy as np
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method

from sklearn.metrics import accuracy_score as sklearn_accuracy

from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.metrics.metric import Metric
from .utils import find_free_network_port, setup_ddp, _assert_allclose
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

set_start_method("spawn", force=True)

@@ -26,7 +31,7 @@ pool = None

def _test(local_rank: int,
world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet,
metric_class: Type[Metric],
metric_kwargs: Dict[str, Any],


+8 -3 tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py

@@ -2,18 +2,23 @@ from functools import partial
import copy

import pytest
import torch
import numpy as np
from torch.multiprocessing import Pool, set_start_method

from fastNLP.core.metrics import ClassifyFPreRecMetric
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

set_start_method("spawn", force=True)


def _test(local_rank: int, world_size: int, device: torch.device,
def _test(local_rank: int, world_size: int, device: "torch.device",
dataset: DataSet, metric_class, metric_kwargs, metric_result):
metric = metric_class(**metric_kwargs)
# the dataset is handled similarly (each process has its own)


+9 -4 tests/core/metrics/test_span_f1_rec_acc_torch.py

@@ -5,16 +5,21 @@ import os, sys
import copy
from functools import partial

import torch
import torch.distributed
import numpy as np
import socket
from torch.multiprocessing import Pool, set_start_method
# from multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

set_start_method("spawn", force=True)

@@ -44,7 +49,7 @@ pool = None

def _test(local_rank: int,
world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet,
metric_class,
metric_kwargs,


+4 -2 tests/core/metrics/utils.py

@@ -2,9 +2,11 @@ import os, sys
import socket
from typing import Union

import torch
from torch import distributed
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed


def setup_ddp(rank: int, world_size: int, master_port: int) -> None:


+1 -0 tests/core/utils/test_cache_results.py

@@ -3,6 +3,7 @@ import pytest
import subprocess
from io import StringIO
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))

from fastNLP.core.utils.cache_results import cache_results
from fastNLP.core import rank_zero_rm


+2 -1 tests/envs/test_set_backend.py

@@ -1,4 +1,5 @@
import os
import pytest

from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
@@ -9,7 +10,7 @@ def test_dump_fastnlp_envs():
filepath = None
try:
with Capturing() as output:
dump_fastnlp_backend()
dump_fastnlp_backend(backend="torch")
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json')
assert filepath in output[0]
assert os.path.exists(filepath)


+3 -1 tests/helpers/callbacks/helper_callbacks_torch.py

@@ -1,7 +1,9 @@
import torch
from copy import deepcopy

from fastNLP.core.callbacks.callback import Callback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch


class RecordAccumulationStepsCallback_Torch(Callback):


+6 -2 tests/helpers/datasets/torch_data.py

@@ -1,7 +1,11 @@
import torch
from functools import reduce
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset


class TorchNormalDataset(Dataset):


+10 -5 tests/helpers/models/torch_model.py

@@ -1,9 +1,14 @@
import torch
import torch.nn as nn
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch.nn import Module
import torch.nn as nn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module


# 1. the most basic classification model
class TorchNormalModel_Classification_1(nn.Module):
class TorchNormalModel_Classification_1(Module):
"""
Implements train_step and evaluate_step separately;
"""
@@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module):
return {"preds": x, "target": y}


class TorchNormalModel_Classification_2(nn.Module):
class TorchNormalModel_Classification_2(Module):
"""
Implements only a forward function, to test the scenario where the user initializes DDP outside by themselves;
"""
@@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y}


class TorchNormalModel_Classification_3(nn.Module):
class TorchNormalModel_Classification_3(Module):
"""
Implements only a forward function, to test the scenario where the user initializes DDP outside by themselves;
auto_param_call is disabled, and forward takes only a single batch argument;


+6 -0 tests/pytest.ini

@@ -0,0 +1,6 @@
[pytest]
markers =
torch
paddle
jittor
torchpaddle
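Side note: tests/pytest.ini registers the custom markers (torch, paddle, jittor, torchpaddle) used by the @pytest.mark.* decorators above, which silences pytest's unknown-marker warnings and lets backend-specific tests be selected with the standard -m filter. A minimal sketch (hypothetical test function):

    import pytest

    @pytest.mark.torch  # select with `pytest -m torch`
    def test_torch_smoke():
        torch = pytest.importorskip("torch")  # skip if torch is not installed
        assert torch.zeros(2).sum().item() == 0.0

    if __name__ == "__main__":
        # equivalent to running `pytest -m torch` on the command line
        pytest.main(["-m", "torch", __file__])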
