Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
66eeee653b
28 changed files with 398 additions and 213 deletions
  1. +1
    -0
      fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  2. +10
    -1
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  3. +2
    -2
      fastNLP/core/utils/dummy_class.py
  4. +4
    -4
      tests/core/callbacks/test_callback_event.py
  5. +11
    -7
      tests/core/callbacks/test_checkpoint_callback_torch.py
  6. +6
    -4
      tests/core/callbacks/test_more_evaluate_callback.py
  7. +1
    -0
      tests/core/collators/padders/test_get_padder.py
  8. +18
    -18
      tests/core/collators/padders/test_paddle_padder.py
  9. +1
    -2
      tests/core/collators/padders/test_raw_padder.py
  10. +1
    -1
      tests/core/collators/test_collator.py
  11. +105
    -8
      tests/core/controllers/test_trainer_event_trigger.py
  12. +5
    -5
      tests/core/controllers/test_trainer_other_things.py
  13. +6
    -4
      tests/core/controllers/test_trainer_w_evaluator_torch.py
  14. +7
    -3
      tests/core/controllers/test_trainer_wo_evaluator_torch.py
  15. +3
    -5
      tests/core/controllers/utils/test_utils.py
  16. +2
    -1
      tests/core/drivers/jittor_driver/test_single_device.py
  17. +149
    -118
      tests/core/drivers/torch_driver/test_ddp.py
  18. +8
    -8
      tests/core/drivers/torch_driver/test_initialize_torch_driver.py
  19. +9
    -4
      tests/core/metrics/test_accuracy_torch.py
  20. +8
    -3
      tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
  21. +9
    -4
      tests/core/metrics/test_span_f1_rec_acc_torch.py
  22. +4
    -2
      tests/core/metrics/utils.py
  23. +1
    -0
      tests/core/utils/test_cache_results.py
  24. +2
    -1
      tests/envs/test_set_backend.py
  25. +3
    -1
      tests/helpers/callbacks/helper_callbacks_torch.py
  26. +6
    -2
      tests/helpers/datasets/torch_data.py
  27. +10
    -5
      tests/helpers/models/torch_model.py
  28. +6
    -0
      tests/pytest.ini

+ 1
- 0
fastNLP/core/dataloaders/jittor_dataloader/fdl.py View File

@@ -91,6 +91,7 @@ class JittorDataLoader:
self.dataset.dataset.set_attrs(batch_size=1) self.dataset.dataset.set_attrs(batch_size=1)
# 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数
# self._collate_fn = _collate_fn # self._collate_fn = _collate_fn
self.cur_batch_indices = None


def __iter__(self): def __iter__(self):
# TODO 第一次迭代后不能设置collate_fn,设置是无效的 # TODO 第一次迭代后不能设置collate_fn,设置是无效的


+ 10
- 1
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -3,7 +3,7 @@ __all__ = [
'prepare_torch_dataloader' 'prepare_torch_dataloader'
] ]


from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator from fastNLP.core.collators import Collator
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):


if sampler is None and batch_sampler is None: if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle) sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False


super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]:
"""
获取当前 batch 的 idx

:return:
"""
return self.cur_batch_indices



def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1, batch_size: int = 1,


+ 2
- 2
fastNLP/core/utils/dummy_class.py View File

@@ -1,5 +1,5 @@
import functools import functools


class DummyClass: class DummyClass:
def __call__(self, *args, **kwargs):
return
def __init__(self, *args, **kwargs):
pass

+ 4
- 4
tests/core/callbacks/test_callback_event.py View File

@@ -162,7 +162,7 @@ class TestCallbackEvents:
def test_every(self): def test_every(self):


# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试;
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1;
event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1;
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data): def _fn(data):
return data return data
@@ -174,7 +174,7 @@ class TestCallbackEvents:
_res.append(cu_res) _res.append(cu_res)
assert _res == list(range(100)) assert _res == list(range(100))


event_state = Events.on_train_begin(every=10)
event_state = Event.on_train_begin(every=10)
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data): def _fn(data):
return data return data
@@ -187,7 +187,7 @@ class TestCallbackEvents:
assert _res == [w - 1 for w in range(10, 101, 10)] assert _res == [w - 1 for w in range(10, 101, 10)]


def test_once(self): def test_once(self):
event_state = Events.on_train_begin(once=10)
event_state = Event.on_train_begin(once=10)


@Filter(once=event_state.once) @Filter(once=event_state.once)
def _fn(data): def _fn(data):
@@ -220,7 +220,7 @@ def test_callback_events_torch():
return True return True
return False return False


event_state = Events.on_train_begin(filter_fn=filter_fn)
event_state = Event.on_train_begin(filter_fn=filter_fn)


@Filter(filter_fn=event_state.filter_fn) @Filter(filter_fn=event_state.filter_fn)
def _fn(trainer, data): def _fn(trainer, data):


+ 11
- 7
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -2,9 +2,6 @@ import os
import pytest import pytest
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path from pathlib import Path
import re import re
import time import time
@@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy from torchmetrics import Accuracy
from fastNLP.core.log import logger from fastNLP.core.log import logger


from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist


@dataclass @dataclass
class ArgMaxDatasetConfig: class ArgMaxDatasetConfig:
@@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2(


if version == 0: if version == 0:
callbacks = [ callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc", monitor="acc",
folder=path, folder=path,
every_n_epochs=None, every_n_epochs=None,
@@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2(
topk=None, topk=None,
last=False, last=False,
on_exception=None, on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
) )
] ]
elif version == 1: elif version == 1:
callbacks = [ callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc", monitor="acc",
folder=path, folder=path,
every_n_epochs=None, every_n_epochs=None,
@@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2(
topk=1, topk=1,
last=True, last=True,
on_exception=None, on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
) )
] ]




+ 6
- 4
tests/core/callbacks/test_more_evaluate_callback.py View File

@@ -12,9 +12,7 @@ import os
import pytest import pytest
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

from pathlib import Path from pathlib import Path
import re import re


@@ -29,7 +27,11 @@ from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback from fastNLP.core.callbacks import MoreEvaluateCallback

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist


@dataclass @dataclass
class ArgMaxDatasetConfig: class ArgMaxDatasetConfig:


+ 1
- 0
tests/core/collators/padders/test_get_padder.py View File

@@ -17,6 +17,7 @@ def test_get_element_shape_dtype():
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.paddle @pytest.mark.paddle
@pytest.mark.jittor
def test_get_padder_run(backend): def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch': if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch") pytest.skip("No torch")


+ 18
- 18
tests/core/collators/padders/test_paddle_padder.py View File

@@ -1,7 +1,7 @@
import numpy as np import numpy as np
import pytest import pytest


from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder
from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE


@@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE:




@pytest.mark.paddle @pytest.mark.paddle
class TestpaddleNumberPadder:
class TestPaddleNumberPadder:
def test_run(self): def test_run(self):
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3] a = [1, 2, 3]
t_a = padder(a) t_a = padder(a)
assert isinstance(t_a, paddle.Tensor) assert isinstance(t_a, paddle.Tensor)
@@ -20,9 +20,9 @@ class TestpaddleNumberPadder:




@pytest.mark.paddle @pytest.mark.paddle
class TestpaddleSequencePadder:
class TestPaddleSequencePadder:
def test_run(self): def test_run(self):
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]] a = [[1, 2, 3], [3]]
a = padder(a) a = padder(a)
shape = a.shape shape = a.shape
@@ -32,20 +32,20 @@ class TestpaddleSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1] assert (a == b).sum().item() == shape[0]*shape[1]


def test_dtype_check(self): def test_dtype_check(self):
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError): with pytest.raises(DtypeError):
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]]) a = padder([[1], [2, 322]])
# assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 # assert (a>67).sum()==0 # 因为int8的范围为-67 - 66
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)




@pytest.mark.paddle @pytest.mark.paddle
class TestpaddleTensorPadder:
class TestPaddleTensorPadder:
def test_run(self): def test_run(self):
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
a = [paddle.zeros((3,)), paddle.zeros((2,))] a = [paddle.zeros((3,)), paddle.zeros((2,))]
a = padder(a) a = padder(a)
shape = a.shape shape = a.shape
@@ -74,7 +74,7 @@ class TestpaddleTensorPadder:
[[0, -1], [-1, -1], [-1, -1]]]) [[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]


padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))]
a = padder(a) a = padder(a)
shape = a.shape shape = a.shape
@@ -85,7 +85,7 @@ class TestpaddleTensorPadder:
]) ])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]


padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)]
a = padder(a) a = padder(a)
shape = a.shape shape = a.shape
@@ -96,11 +96,11 @@ class TestpaddleTensorPadder:
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]


def test_dtype_check(self): def test_dtype_check(self):
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError): with pytest.raises(DtypeError):
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)


def test_v1(self): def test_v1(self):
print(paddle.zeros((3, )).dtype) print(paddle.zeros((3, )).dtype)

+ 1
- 2
tests/core/collators/padders/test_raw_padder.py View File

@@ -23,7 +23,6 @@ class TestRawSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1] assert (a == b).sum().item() == shape[0]*shape[1]


def test_dtype_check(self): def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError): with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int)

+ 1
- 1
tests/core/collators/test_collator.py View File

@@ -4,7 +4,7 @@ import pytest


from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR


from fastNLP.core.collators.new_collator import Collator
from fastNLP.core.collators.collator import Collator




def _assert_equal(d1, d2): def _assert_equal(d1, d2):


+ 105
- 8
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -1,10 +1,7 @@
import pytest import pytest
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist



from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks.callback_event import Event from fastNLP.core.callbacks.callback_event import Event
@@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback
from tests.helpers.utils import magic_argv_env_context, Capturing from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist




@dataclass @dataclass
@@ -96,10 +99,10 @@ def test_trainer_event_trigger_1(
if dist.is_initialized(): if dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()


Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]


@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@@ -211,7 +214,101 @@ def test_trainer_event_trigger_2(
) )


trainer.run() trainer.run()

if dist.is_initialized():
dist.destroy_process_group()

Event_attrs = Event.__dict__ Event_attrs = Event.__dict__
for k, v in Event_attrs.items(): for k, v in Event_attrs.items():
if isinstance(v, staticmethod): if isinstance(v, staticmethod):
assert k in output[0] assert k in output[0]


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
import re

once_message_1 = "This message should be typed 1 times."
once_message_2 = "test_filter_fn"
once_message_3 = "once message 3"
twice_message = "twice message hei hei"

@Trainer.on(Event.on_train_epoch_begin(every=2))
def train_epoch_begin_1(trainer):
print(once_message_1)

@Trainer.on(Event.on_train_epoch_begin())
def train_epoch_begin_2(trainer):
print(twice_message)

@Trainer.on(Event.on_train_epoch_begin(once=2))
def train_epoch_begin_3(trainer):
print(once_message_3)

def filter_fn(filter, trainer):
if trainer.cur_epoch_idx == 1:
return True
else:
return False

@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn))
def test_filter_fn(trainer):
print(once_message_2)

with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()


once_pattern_1 = re.compile(once_message_1)
once_pattern_2 = re.compile(once_message_2)
once_pattern_3 = re.compile(once_message_3)
twice_pattern = re.compile(twice_message)

once_res_1 = once_pattern_1.findall(output[0])
assert len(once_res_1) == 1
once_res_2 = once_pattern_2.findall(output[0])
assert len(once_res_2) == 1
once_res_3 = once_pattern_3.findall(output[0])
assert len(once_res_3) == 1
twice_res = twice_pattern.findall(output[0])
assert len(twice_res) == 2

















+ 5
- 5
tests/core/controllers/test_trainer_other_things.py View File

@@ -1,22 +1,22 @@
import pytest import pytest


from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from fastNLP.core.callbacks import Event
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context




@magic_argv_env_context @magic_argv_env_context
def test_trainer_torch_without_evaluator(): def test_trainer_torch_without_evaluator():
@Trainer.on(Events.on_train_epoch_begin(every=10))
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things")
def fn1(trainer): def fn1(trainer):
pass pass


@Trainer.on(Events.on_train_batch_begin(every=10))
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn2(trainer, batch, indices): def fn2(trainer, batch, indices):
pass pass


with pytest.raises(AssertionError):
@Trainer.on(Events.on_train_batch_begin(every=10))
with pytest.raises(BaseException):
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn3(trainer, batch): def fn3(trainer, batch):
pass pass




+ 6
- 4
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -2,9 +2,7 @@
注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; 注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成;
""" """
import pytest import pytest
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist

from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from torchmetrics import Accuracy from torchmetrics import Accuracy
@@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist


@dataclass @dataclass
class NormalClassificationTrainTorchConfig: class NormalClassificationTrainTorchConfig:


+ 7
- 3
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -2,9 +2,7 @@ import os.path
import subprocess import subprocess
import sys import sys
import pytest import pytest
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader

from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from pathlib import Path from pathlib import Path
@@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.core import rank_zero_rm from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader




@dataclass @dataclass
@@ -286,6 +289,7 @@ def test_trainer_on_exception(
dist.destroy_process_group() dist.destroy_process_group()




@pytest.mark.torch
@pytest.mark.parametrize("version", [0, 1, 2, 3]) @pytest.mark.parametrize("version", [0, 1, 2, 3])
@magic_argv_env_context @magic_argv_env_context
def test_torch_distributed_launch_1(version): def test_torch_distributed_launch_1(version):


+ 3
- 5
tests/core/controllers/utils/test_utils.py View File

@@ -11,7 +11,7 @@ class Test_WrapDataLoader:
for sanity_batches in all_sanity_batches: for sanity_batches in all_sanity_batches:
data = NormalSampler(num_of_data=1000) data = NormalSampler(num_of_data=1000)
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
dataloader = iter(wrapper(dataloader=data))
dataloader = iter(wrapper)
mark = 0 mark = 0
while True: while True:
try: try:
@@ -32,8 +32,7 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000) dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
dataloader = iter(dataloader)
dataloader = iter(wrapper)
all_supposed_running_data_num = 0 all_supposed_running_data_num = 0
while True: while True:
try: try:
@@ -55,6 +54,5 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000) dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
length.append(len(dataloader))
length.append(len(wrapper))
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))])

+ 2
- 1
tests/core/drivers/jittor_driver/test_single_device.py View File

@@ -15,7 +15,7 @@ else:






class Model (Module):
class Model(Module):
def __init__ (self): def __init__ (self):
super (Model, self).__init__() super (Model, self).__init__()
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
@@ -45,6 +45,7 @@ class Model (Module):
return x return x


@pytest.mark.jittor @pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice: class TestSingleDevice:


def test_on_gpu_without_fp16(self): def test_on_gpu_without_fp16(self):


+ 149
- 118
tests/core/drivers/torch_driver/test_ddp.py View File

@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler


import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler

def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device] device = [torch.device(i) for i in device]
@@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
# #
############################################################################ ############################################################################


@pytest.mark.torch
@magic_argv_env_context
def test_multi_drivers():
"""
测试使用了多个 TorchDDPDriver 的情况。
"""
generate_driver(10, 10)
generate_driver(20, 10)
with pytest.raises(RuntimeError):
# 设备设置不同,应该报错
generate_driver(20, 3, device=[0,1,2])
assert False
dist.barrier()

if dist.is_initialized():
dist.destroy_process_group()

@pytest.mark.torch @pytest.mark.torch
class TestDDPDriverFunction: class TestDDPDriverFunction:
""" """
测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题
""" """


@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)

@magic_argv_env_context @magic_argv_env_context
def test_multi_drivers(self):
def test_simple_functions(self):
""" """
测试使用了多个 TorchDDPDriver 的情况。
简单测试多个函数
""" """
driver2 = generate_driver(20, 10)
with pytest.raises(RuntimeError):
# 设备设置不同,应该报错
driver3 = generate_driver(20, 3, device=[0,1,2])
assert False
dist.barrier()
driver = generate_driver(10, 10)


@magic_argv_env_context
def test_move_data_to_device(self):
""" """
这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中
就不重复测试了
测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在
tests/core/utils/test_torch_utils.py中,就不重复测试了
""" """
self.driver.move_data_to_device(torch.rand((32, 64)))

driver.move_data_to_device(torch.rand((32, 64)))
dist.barrier() dist.barrier()


@magic_argv_env_context
def test_is_distributed(self):
""" """
测试 is_distributed 函数 测试 is_distributed 函数
""" """
assert self.driver.is_distributed() == True
assert driver.is_distributed() == True
dist.barrier() dist.barrier()


@magic_argv_env_context
def test_get_no_sync_context(self):
""" """
测试 get_no_sync_context 函数 测试 get_no_sync_context 函数
""" """
res = self.driver.get_model_no_sync_context()
res = driver.get_model_no_sync_context()
dist.barrier() dist.barrier()


@magic_argv_env_context
def test_is_global_zero(self):
""" """
测试 is_global_zero 函数 测试 is_global_zero 函数
""" """
self.driver.is_global_zero()
driver.is_global_zero()
dist.barrier() dist.barrier()


@magic_argv_env_context
def test_unwrap_model(self):
""" """
测试 unwrap_model 函数 测试 unwrap_model 函数
""" """
self.driver.unwrap_model()
driver.unwrap_model()
dist.barrier() dist.barrier()


@magic_argv_env_context
def test_get_local_rank(self):
""" """
测试 get_local_rank 函数 测试 get_local_rank 函数
""" """
self.driver.get_local_rank()
driver.get_local_rank()
dist.barrier() dist.barrier()


@magic_argv_env_context
def test_all_gather(self):
""" """
测试 all_gather 函数 测试 all_gather 函数
详细的测试在 test_dist_utils.py 中完成 详细的测试在 test_dist_utils.py 中完成
""" """
obj = { obj = {
"rank": self.driver.global_rank
"rank": driver.global_rank
} }
obj_list = self.driver.all_gather(obj, group=None)
obj_list = driver.all_gather(obj, group=None)
for i, res in enumerate(obj_list): for i, res in enumerate(obj_list):
assert res["rank"] == i assert res["rank"] == i


@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_broadcast_object(self, src_rank):
""" """
测试 broadcast_object 函数 测试 broadcast_object 函数
详细的函数在 test_dist_utils.py 中完成 详细的函数在 test_dist_utils.py 中完成
""" """
if self.driver.global_rank == src_rank:
if driver.global_rank == 0:
obj = { obj = {
"rank": self.driver.global_rank
"rank": driver.global_rank
} }
else: else:
obj = None obj = None
res = self.driver.broadcast_object(obj, src=src_rank)
assert res["rank"] == src_rank
res = driver.broadcast_object(obj, src=0)
assert res["rank"] == 0

if dist.is_initialized():
dist.destroy_process_group()


############################################################################ ############################################################################
# #
@@ -187,7 +180,6 @@ class TestSetDistReproDataloader:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
cls.device = [0, 1] cls.device = [0, 1]
cls.driver = generate_driver(10, 10, device=cls.device)


def setup_method(self): def setup_method(self):
self.dataset = TorchNormalDataset(40) self.dataset = TorchNormalDataset(40)
@@ -204,17 +196,20 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler) self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -223,9 +218,10 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -234,9 +230,11 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)


dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
""" """
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler`
@@ -251,15 +249,17 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# 应当抛出 RuntimeError # 应当抛出 RuntimeError
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True)


dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
# @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
""" """
@@ -268,21 +268,24 @@ class TestSetDistReproDataloader:
此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler
和原 dataloader 相同 和原 dataloader 相同
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
dataloader.batch_sampler.set_distributed( dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
num_replicas=driver.world_size,
rank=driver.global_rank,
pad=True pad=True
) )
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler) self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)


dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -292,12 +295,13 @@ class TestSetDistReproDataloader:
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其
batch_sampler.sampler 和原 dataloader 相同 batch_sampler.sampler 和原 dataloader 相同
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed( dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
num_replicas=driver.world_size,
rank=driver.global_rank
) )
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -307,9 +311,11 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.drop_last == False assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -318,11 +324,14 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
此时直接返回原来的 dataloader,不做任何处理。 此时直接返回原来的 dataloader,不做任何处理。
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)


assert replaced_loader is dataloader assert replaced_loader is dataloader
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


""" """
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数
@@ -337,12 +346,13 @@ class TestSetDistReproDataloader:
的表现 的表现
此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader( dataloader = DataLoader(
dataset=self.dataset, dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
) )
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
@@ -351,6 +361,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler) self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -361,8 +373,9 @@ class TestSetDistReproDataloader:
此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关
的属性 的属性
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
@@ -372,6 +385,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -381,8 +396,9 @@ class TestSetDistReproDataloader:
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关
的属性 的属性
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -392,6 +408,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


""" """
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数
@@ -407,8 +425,9 @@ class TestSetDistReproDataloader:
此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关
的属性 的属性
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -418,6 +437,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -427,8 +448,9 @@ class TestSetDistReproDataloader:
的表现 的表现
此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -439,6 +461,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -448,8 +472,9 @@ class TestSetDistReproDataloader:
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关
的属性 的属性
""" """
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -459,6 +484,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier() dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()


def check_distributed_sampler(self, sampler): def check_distributed_sampler(self, sampler):
""" """
@@ -469,7 +496,7 @@ class TestSetDistReproDataloader:
if not isinstance(sampler, UnrepeatedSampler): if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True assert sampler.pad == True


def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle):
""" """
测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确
""" """
@@ -501,8 +528,8 @@ class TestSetDistReproDataloader:
drop_last=False, drop_last=False,
) )
new_loader.batch_sampler.set_distributed( new_loader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
num_replicas=driver.world_size,
rank=driver.global_rank,
pad=True pad=True
) )
new_loader.batch_sampler.load_state_dict(sampler_states) new_loader.batch_sampler.load_state_dict(sampler_states)
@@ -512,8 +539,8 @@ class TestSetDistReproDataloader:
# 重新构造 dataloader # 重新构造 dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
new_loader.batch_sampler.sampler.set_distributed( new_loader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
num_replicas=driver.world_size,
rank=driver.global_rank
) )
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader): for idx, batch in enumerate(new_loader):
@@ -534,11 +561,6 @@ class TestSaveLoad:
测试多卡情况下 save 和 load 相关函数的表现 测试多卡情况下 save 和 load 相关函数的表现
""" """


@classmethod
def setup_class(cls):
# 不在这里 setup 的话会报错
cls.driver = generate_driver(10, 10)

def setup_method(self): def setup_method(self):
self.dataset = TorchArgMaxDataset(10, 20) self.dataset = TorchArgMaxDataset(10, 20)


@@ -552,26 +574,26 @@ class TestSaveLoad:
path = "model" path = "model"


dataloader = DataLoader(self.dataset, batch_size=2) dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)


self.driver1.save_model(path, only_state_dict)
driver1.save_model(path, only_state_dict)


# 同步 # 同步
dist.barrier() dist.barrier()
self.driver2.load_model(path, only_state_dict)
driver2.load_model(path, only_state_dict)


for idx, batch in enumerate(dataloader): for idx, batch in enumerate(dataloader):
batch = self.driver1.move_data_to_device(batch)
res1 = self.driver1.model(
batch = driver1.move_data_to_device(batch)
res1 = driver1.model(
batch, batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None, fastnlp_signature_fn=None,
wo_auto_param_call=False, wo_auto_param_call=False,
) )
res2 = self.driver2.model(
res2 = driver2.model(
batch, batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None, fastnlp_signature_fn=None,
wo_auto_param_call=False, wo_auto_param_call=False,
) )
@@ -580,6 +602,9 @@ class TestSaveLoad:
finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False]))
@@ -593,7 +618,7 @@ class TestSaveLoad:
path = "model.ckp" path = "model.ckp"
num_replicas = len(device) num_replicas = len(device)


self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False) generate_driver(10, 10, device=device, fp16=False)
dataloader = dataloader_with_bucketedbatchsampler( dataloader = dataloader_with_bucketedbatchsampler(
self.dataset, self.dataset,
@@ -603,8 +628,8 @@ class TestSaveLoad:
drop_last=False drop_last=False
) )
dataloader.batch_sampler.set_distributed( dataloader.batch_sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
num_replicas=driver1.world_size,
rank=driver1.global_rank,
pad=True pad=True
) )
num_consumed_batches = 2 num_consumed_batches = 2
@@ -623,7 +648,7 @@ class TestSaveLoad:
# 保存状态 # 保存状态
sampler_states = dataloader.batch_sampler.state_dict() sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches} save_states = {"num_consumed_batches": num_consumed_batches}
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
# 加载 # 加载
# 更改 batch_size # 更改 batch_size
dataloader = dataloader_with_bucketedbatchsampler( dataloader = dataloader_with_bucketedbatchsampler(
@@ -634,11 +659,11 @@ class TestSaveLoad:
drop_last=False drop_last=False
) )
dataloader.batch_sampler.set_distributed( dataloader.batch_sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
num_replicas=driver2.world_size,
rank=driver2.global_rank,
pad=True pad=True
) )
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader") replaced_loader = load_states.pop("dataloader")
# 1. 检查 optimizer 的状态 # 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空 # TODO optimizer 的 state_dict 总是为空
@@ -652,7 +677,7 @@ class TestSaveLoad:


# 3. 检查 fp16 是否被加载 # 3. 检查 fp16 是否被加载
if fp16: if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)


# 4. 检查 model 的参数是否正确 # 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx # 5. 检查 batch_idx
@@ -664,16 +689,16 @@ class TestSaveLoad:


left_x_batches.update(batch["x"]) left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"]) left_y_batches.update(batch["y"])
res1 = self.driver1.model(
res1 = driver1.model(
batch, batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None, fastnlp_signature_fn=None,
wo_auto_param_call=False, wo_auto_param_call=False,
) )
res2 = self.driver2.model(
res2 = driver2.model(
batch, batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None, fastnlp_signature_fn=None,
wo_auto_param_call=False, wo_auto_param_call=False,
) )
@@ -686,6 +711,9 @@ class TestSaveLoad:
finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


if dist.is_initialized():
dist.destroy_process_group()

@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False]))
@@ -700,13 +728,13 @@ class TestSaveLoad:


num_replicas = len(device) num_replicas = len(device)


self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
driver1 = generate_driver(10, 10, device=device, fp16=fp16)
driver2 = generate_driver(10, 10, device=device, fp16=False)


dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed( dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
num_replicas=driver1.world_size,
rank=driver1.global_rank,
pad=True pad=True
) )
num_consumed_batches = 2 num_consumed_batches = 2
@@ -726,18 +754,18 @@ class TestSaveLoad:
sampler_states = dataloader.batch_sampler.sampler.state_dict() sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches} save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict: if only_state_dict:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else: else:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
# 加载 # 加载
# 更改 batch_size # 更改 batch_size
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed( dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
num_replicas=driver2.world_size,
rank=driver2.global_rank,
pad=True pad=True
) )
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader") replaced_loader = load_states.pop("dataloader")


# 1. 检查 optimizer 的状态 # 1. 检查 optimizer 的状态
@@ -753,7 +781,7 @@ class TestSaveLoad:
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. 检查 fp16 是否被加载 # 3. 检查 fp16 是否被加载
if fp16: if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)


# 4. 检查 model 的参数是否正确 # 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx # 5. 检查 batch_idx
@@ -765,16 +793,16 @@ class TestSaveLoad:


left_x_batches.update(batch["x"]) left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"]) left_y_batches.update(batch["y"])
res1 = self.driver1.model(
res1 = driver1.model(
batch, batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None, fastnlp_signature_fn=None,
wo_auto_param_call=False, wo_auto_param_call=False,
) )
res2 = self.driver2.model(
res2 = driver2.model(
batch, batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None, fastnlp_signature_fn=None,
wo_auto_param_call=False, wo_auto_param_call=False,
) )
@@ -786,4 +814,7 @@ class TestSaveLoad:
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas


finally: finally:
rank_zero_rm(path)
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()

+ 8
- 8
tests/core/drivers/torch_driver/test_initialize_torch_driver.py View File

@@ -2,12 +2,14 @@ import pytest


from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
from fastNLP.envs import get_gpu_count
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context

import torch

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import device as torchdevice
else:
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice


@pytest.mark.torch @pytest.mark.torch
def test_incorrect_driver(): def test_incorrect_driver():
@@ -20,7 +22,7 @@ def test_incorrect_driver():
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device", "device",
["cpu", "cuda:0", 0, torch.device("cuda:0")]
["cpu", "cuda:0", 0, torchdevice("cuda:0")]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
@@ -83,7 +85,6 @@ def test_get_ddp(driver, device):
("driver", "device"), ("driver", "device"),
[("torch_ddp", "cpu")] [("torch_ddp", "cpu")]
) )
@magic_argv_env_context
def test_get_ddp_cpu(driver, device): def test_get_ddp_cpu(driver, device):
""" """
测试试图在 cpu 上初始化分布式训练的情况 测试试图在 cpu 上初始化分布式训练的情况
@@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device):
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device", "device",
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1]
[-2, [0, 20, 3], [-2], 20]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
["torch", "torch_ddp"] ["torch", "torch_ddp"]
) )
@magic_argv_env_context
def test_device_out_of_range(driver, device): def test_device_out_of_range(driver, device):
""" """
测试传入的device超过范围的情况 测试传入的device超过范围的情况


+ 9
- 4
tests/core/metrics/test_accuracy_torch.py View File

@@ -7,15 +7,20 @@ import copy
import socket import socket
import pytest import pytest
import numpy as np import numpy as np
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method

from sklearn.metrics import accuracy_score as sklearn_accuracy from sklearn.metrics import accuracy_score as sklearn_accuracy


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.metrics.metric import Metric from fastNLP.core.metrics.metric import Metric
from .utils import find_free_network_port, setup_ddp, _assert_allclose from .utils import find_free_network_port, setup_ddp, _assert_allclose
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method


set_start_method("spawn", force=True) set_start_method("spawn", force=True)


@@ -26,7 +31,7 @@ pool = None


def _test(local_rank: int, def _test(local_rank: int,
world_size: int, world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet, dataset: DataSet,
metric_class: Type[Metric], metric_class: Type[Metric],
metric_kwargs: Dict[str, Any], metric_kwargs: Dict[str, Any],


+ 8
- 3
tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py View File

@@ -2,18 +2,23 @@ from functools import partial
import copy import copy


import pytest import pytest
import torch
import numpy as np import numpy as np
from torch.multiprocessing import Pool, set_start_method


from fastNLP.core.metrics import ClassifyFPreRecMetric from fastNLP.core.metrics import ClassifyFPreRecMetric
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method


set_start_method("spawn", force=True) set_start_method("spawn", force=True)




def _test(local_rank: int, world_size: int, device: torch.device,
def _test(local_rank: int, world_size: int, device: "torch.device",
dataset: DataSet, metric_class, metric_kwargs, metric_result): dataset: DataSet, metric_class, metric_kwargs, metric_result):
metric = metric_class(**metric_kwargs) metric = metric_class(**metric_kwargs)
# dataset 也类似(每个进程有自己的一个) # dataset 也类似(每个进程有自己的一个)


+ 9
- 4
tests/core/metrics/test_span_f1_rec_acc_torch.py View File

@@ -5,16 +5,21 @@ import os, sys
import copy import copy
from functools import partial from functools import partial


import torch
import torch.distributed
import numpy as np import numpy as np
import socket import socket
from torch.multiprocessing import Pool, set_start_method
# from multiprocessing import Pool, set_start_method # from multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method


set_start_method("spawn", force=True) set_start_method("spawn", force=True)


@@ -44,7 +49,7 @@ pool = None


def _test(local_rank: int, def _test(local_rank: int,
world_size: int, world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet, dataset: DataSet,
metric_class, metric_class,
metric_kwargs, metric_kwargs,


+ 4
- 2
tests/core/metrics/utils.py View File

@@ -2,9 +2,11 @@ import os, sys
import socket import socket
from typing import Union from typing import Union


import torch
from torch import distributed
import numpy as np import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed




def setup_ddp(rank: int, world_size: int, master_port: int) -> None: def setup_ddp(rank: int, world_size: int, master_port: int) -> None:


+ 1
- 0
tests/core/utils/test_cache_results.py View File

@@ -3,6 +3,7 @@ import pytest
import subprocess import subprocess
from io import StringIO from io import StringIO
import sys import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))


from fastNLP.core.utils.cache_results import cache_results from fastNLP.core.utils.cache_results import cache_results
from fastNLP.core import rank_zero_rm from fastNLP.core import rank_zero_rm


+ 2
- 1
tests/envs/test_set_backend.py View File

@@ -1,4 +1,5 @@
import os import os
import pytest


from fastNLP.envs.set_backend import dump_fastnlp_backend from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing from tests.helpers.utils import Capturing
@@ -9,7 +10,7 @@ def test_dump_fastnlp_envs():
filepath = None filepath = None
try: try:
with Capturing() as output: with Capturing() as output:
dump_fastnlp_backend()
dump_fastnlp_backend(backend="torch")
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json')
assert filepath in output[0] assert filepath in output[0]
assert os.path.exists(filepath) assert os.path.exists(filepath)


+ 3
- 1
tests/helpers/callbacks/helper_callbacks_torch.py View File

@@ -1,7 +1,9 @@
import torch
from copy import deepcopy from copy import deepcopy


from fastNLP.core.callbacks.callback import Callback from fastNLP.core.callbacks.callback import Callback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch




class RecordAccumulationStepsCallback_Torch(Callback): class RecordAccumulationStepsCallback_Torch(Callback):


+ 6
- 2
tests/helpers/datasets/torch_data.py View File

@@ -1,7 +1,11 @@
import torch import torch
from functools import reduce from functools import reduce
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset




class TorchNormalDataset(Dataset): class TorchNormalDataset(Dataset):


+ 10
- 5
tests/helpers/models/torch_model.py View File

@@ -1,9 +1,14 @@
import torch
import torch.nn as nn
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch.nn import Module
import torch.nn as nn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module




# 1. 最为基础的分类模型 # 1. 最为基础的分类模型
class TorchNormalModel_Classification_1(nn.Module):
class TorchNormalModel_Classification_1(Module):
""" """
单独实现 train_step 和 evaluate_step; 单独实现 train_step 和 evaluate_step;
""" """
@@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module):
return {"preds": x, "target": y} return {"preds": x, "target": y}




class TorchNormalModel_Classification_2(nn.Module):
class TorchNormalModel_Classification_2(Module):
""" """
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
""" """
@@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y} return {"loss": loss, "preds": x, "target": y}




class TorchNormalModel_Classification_3(nn.Module):
class TorchNormalModel_Classification_3(Module):
""" """
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
关闭 auto_param_call,forward 只有一个 batch 参数; 关闭 auto_param_call,forward 只有一个 batch 参数;


+ 6
- 0
tests/pytest.ini View File

@@ -0,0 +1,6 @@
[pytest]
markers =
torch
paddle
jittor
torchpaddle

Loading…
Cancel
Save