@@ -91,6 +91,7 @@ class JittorDataLoader: | |||||
self.dataset.dataset.set_attrs(batch_size=1) | self.dataset.dataset.set_attrs(batch_size=1) | ||||
# 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 | # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 | ||||
# self._collate_fn = _collate_fn | # self._collate_fn = _collate_fn | ||||
self.cur_batch_indices = None | |||||
def __iter__(self): | def __iter__(self): | ||||
# TODO 第一次迭代后不能设置collate_fn,设置是无效的 | # TODO 第一次迭代后不能设置collate_fn,设置是无效的 | ||||
@@ -3,7 +3,7 @@ __all__ = [ | |||||
'prepare_torch_dataloader' | 'prepare_torch_dataloader' | ||||
] | ] | ||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping | |||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader): | |||||
if sampler is None and batch_sampler is None: | if sampler is None and batch_sampler is None: | ||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
shuffle=False | |||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | ||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | ||||
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader): | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | ||||
def get_batch_indices(self) -> List[int]: | |||||
""" | |||||
获取当前 batch 的 idx | |||||
:return: | |||||
""" | |||||
return self.cur_batch_indices | |||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | ||||
batch_size: int = 1, | batch_size: int = 1, | ||||
@@ -1,5 +1,5 @@ | |||||
import functools | import functools | ||||
class DummyClass: | class DummyClass: | ||||
def __call__(self, *args, **kwargs): | |||||
return | |||||
def __init__(self, *args, **kwargs): | |||||
pass |
@@ -162,7 +162,7 @@ class TestCallbackEvents: | |||||
def test_every(self): | def test_every(self): | ||||
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; | # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; | ||||
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; | |||||
event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1; | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | ||||
def _fn(data): | def _fn(data): | ||||
return data | return data | ||||
@@ -174,7 +174,7 @@ class TestCallbackEvents: | |||||
_res.append(cu_res) | _res.append(cu_res) | ||||
assert _res == list(range(100)) | assert _res == list(range(100)) | ||||
event_state = Events.on_train_begin(every=10) | |||||
event_state = Event.on_train_begin(every=10) | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | ||||
def _fn(data): | def _fn(data): | ||||
return data | return data | ||||
@@ -187,7 +187,7 @@ class TestCallbackEvents: | |||||
assert _res == [w - 1 for w in range(10, 101, 10)] | assert _res == [w - 1 for w in range(10, 101, 10)] | ||||
def test_once(self): | def test_once(self): | ||||
event_state = Events.on_train_begin(once=10) | |||||
event_state = Event.on_train_begin(once=10) | |||||
@Filter(once=event_state.once) | @Filter(once=event_state.once) | ||||
def _fn(data): | def _fn(data): | ||||
@@ -220,7 +220,7 @@ def test_callback_events_torch(): | |||||
return True | return True | ||||
return False | return False | ||||
event_state = Events.on_train_begin(filter_fn=filter_fn) | |||||
event_state = Event.on_train_begin(filter_fn=filter_fn) | |||||
@Filter(filter_fn=event_state.filter_fn) | @Filter(filter_fn=event_state.filter_fn) | ||||
def _fn(trainer, data): | def _fn(trainer, data): | ||||
@@ -2,9 +2,6 @@ import os | |||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from pathlib import Path | from pathlib import Path | ||||
import re | import re | ||||
import time | import time | ||||
@@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
@@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2( | |||||
if version == 0: | if version == 0: | ||||
callbacks = [ | callbacks = [ | ||||
TrainerCheckpointCallback( | |||||
CheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
folder=path, | folder=path, | ||||
every_n_epochs=None, | every_n_epochs=None, | ||||
@@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2( | |||||
topk=None, | topk=None, | ||||
last=False, | last=False, | ||||
on_exception=None, | on_exception=None, | ||||
model_save_fn=model_save_fn | |||||
model_save_fn=model_save_fn, | |||||
save_object="trainer" | |||||
) | ) | ||||
] | ] | ||||
elif version == 1: | elif version == 1: | ||||
callbacks = [ | callbacks = [ | ||||
TrainerCheckpointCallback( | |||||
CheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
folder=path, | folder=path, | ||||
every_n_epochs=None, | every_n_epochs=None, | ||||
@@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2( | |||||
topk=1, | topk=1, | ||||
last=True, | last=True, | ||||
on_exception=None, | on_exception=None, | ||||
model_save_fn=model_save_fn | |||||
model_save_fn=model_save_fn, | |||||
save_object="trainer" | |||||
) | ) | ||||
] | ] | ||||
@@ -12,9 +12,7 @@ import os | |||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from pathlib import Path | from pathlib import Path | ||||
import re | import re | ||||
@@ -29,7 +27,11 @@ from torchmetrics import Accuracy | |||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.callbacks import MoreEvaluateCallback | from fastNLP.core.callbacks import MoreEvaluateCallback | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
@@ -17,6 +17,7 @@ def test_get_element_shape_dtype(): | |||||
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
@pytest.mark.jittor | |||||
def test_get_padder_run(backend): | def test_get_padder_run(backend): | ||||
if not _NEED_IMPORT_TORCH and backend == 'torch': | if not _NEED_IMPORT_TORCH and backend == 'torch': | ||||
pytest.skip("No torch") | pytest.skip("No torch") | ||||
@@ -1,7 +1,7 @@ | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder | |||||
from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | from fastNLP.core.collators.padders.exceptions import DtypeError | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
@@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE: | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class TestpaddleNumberPadder: | |||||
class TestPaddleNumberPadder: | |||||
def test_run(self): | def test_run(self): | ||||
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
t_a = padder(a) | t_a = padder(a) | ||||
assert isinstance(t_a, paddle.Tensor) | assert isinstance(t_a, paddle.Tensor) | ||||
@@ -20,9 +20,9 @@ class TestpaddleNumberPadder: | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class TestpaddleSequencePadder: | |||||
class TestPaddleSequencePadder: | |||||
def test_run(self): | def test_run(self): | ||||
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -32,20 +32,20 @@ class TestpaddleSequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) | |||||
a = padder([[1], [2, 322]]) | a = padder([[1], [2, 322]]) | ||||
# assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | # assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | ||||
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class TestpaddleTensorPadder: | |||||
class TestPaddleTensorPadder: | |||||
def test_run(self): | def test_run(self): | ||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) | |||||
a = [paddle.zeros((3,)), paddle.zeros((2,))] | a = [paddle.zeros((3,)), paddle.zeros((2,))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -74,7 +74,7 @@ class TestpaddleTensorPadder: | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | [[0, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) | |||||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] | a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -85,7 +85,7 @@ class TestpaddleTensorPadder: | |||||
]) | ]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) | |||||
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] | a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -96,11 +96,11 @@ class TestpaddleTensorPadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) | |||||
def test_v1(self): | def test_v1(self): | ||||
print(paddle.zeros((3, )).dtype) | print(paddle.zeros((3, )).dtype) |
@@ -23,7 +23,6 @@ class TestRawSequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
with pytest.raises(DtypeError): | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) |
@@ -4,7 +4,7 @@ import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.collators.new_collator import Collator | |||||
from fastNLP.core.collators.collator import Collator | |||||
def _assert_equal(d1, d2): | def _assert_equal(d1, d2): | ||||
@@ -1,10 +1,7 @@ | |||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from torchmetrics import Accuracy | |||||
import torch.distributed as dist | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.callbacks.callback_event import Event | from fastNLP.core.callbacks.callback_event import Event | ||||
@@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from torchmetrics import Accuracy | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
@@ -96,10 +99,10 @@ def test_trainer_event_trigger_1( | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | ||||
@@ -211,7 +214,101 @@ def test_trainer_event_trigger_2( | |||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
Event_attrs = Event.__dict__ | Event_attrs = Event.__dict__ | ||||
for k, v in Event_attrs.items(): | for k, v in Event_attrs.items(): | ||||
if isinstance(v, staticmethod): | if isinstance(v, staticmethod): | ||||
assert k in output[0] | assert k in output[0] | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) | |||||
@pytest.mark.torch | |||||
@magic_argv_env_context | |||||
def test_trainer_event_trigger_3( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=2, | |||||
): | |||||
import re | |||||
once_message_1 = "This message should be typed 1 times." | |||||
once_message_2 = "test_filter_fn" | |||||
once_message_3 = "once message 3" | |||||
twice_message = "twice message hei hei" | |||||
@Trainer.on(Event.on_train_epoch_begin(every=2)) | |||||
def train_epoch_begin_1(trainer): | |||||
print(once_message_1) | |||||
@Trainer.on(Event.on_train_epoch_begin()) | |||||
def train_epoch_begin_2(trainer): | |||||
print(twice_message) | |||||
@Trainer.on(Event.on_train_epoch_begin(once=2)) | |||||
def train_epoch_begin_3(trainer): | |||||
print(once_message_3) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.cur_epoch_idx == 1: | |||||
return True | |||||
else: | |||||
return False | |||||
@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn)) | |||||
def test_filter_fn(trainer): | |||||
print(once_message_2) | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
once_pattern_1 = re.compile(once_message_1) | |||||
once_pattern_2 = re.compile(once_message_2) | |||||
once_pattern_3 = re.compile(once_message_3) | |||||
twice_pattern = re.compile(twice_message) | |||||
once_res_1 = once_pattern_1.findall(output[0]) | |||||
assert len(once_res_1) == 1 | |||||
once_res_2 = once_pattern_2.findall(output[0]) | |||||
assert len(once_res_2) == 1 | |||||
once_res_3 = once_pattern_3.findall(output[0]) | |||||
assert len(once_res_3) == 1 | |||||
twice_res = twice_pattern.findall(output[0]) | |||||
assert len(twice_res) == 2 | |||||
@@ -1,22 +1,22 @@ | |||||
import pytest | import pytest | ||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.callbacks import Events | |||||
from fastNLP.core.callbacks import Event | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator(): | def test_trainer_torch_without_evaluator(): | ||||
@Trainer.on(Events.on_train_epoch_begin(every=10)) | |||||
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things") | |||||
def fn1(trainer): | def fn1(trainer): | ||||
pass | pass | ||||
@Trainer.on(Events.on_train_batch_begin(every=10)) | |||||
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") | |||||
def fn2(trainer, batch, indices): | def fn2(trainer, batch, indices): | ||||
pass | pass | ||||
with pytest.raises(AssertionError): | |||||
@Trainer.on(Events.on_train_batch_begin(every=10)) | |||||
with pytest.raises(BaseException): | |||||
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") | |||||
def fn3(trainer, batch): | def fn3(trainer, batch): | ||||
pass | pass | ||||
@@ -2,9 +2,7 @@ | |||||
注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; | 注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; | ||||
""" | """ | ||||
import pytest | import pytest | ||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Any | from typing import Any | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
@@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | ||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
@@ -2,9 +2,7 @@ import os.path | |||||
import subprocess | import subprocess | ||||
import sys | import sys | ||||
import pytest | import pytest | ||||
import torch.distributed as dist | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Any | from typing import Any | ||||
from pathlib import Path | from pathlib import Path | ||||
@@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch.distributed as dist | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
@dataclass | @dataclass | ||||
@@ -286,6 +289,7 @@ def test_trainer_on_exception( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | @pytest.mark.parametrize("version", [0, 1, 2, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_distributed_launch_1(version): | def test_torch_distributed_launch_1(version): | ||||
@@ -11,7 +11,7 @@ class Test_WrapDataLoader: | |||||
for sanity_batches in all_sanity_batches: | for sanity_batches in all_sanity_batches: | ||||
data = NormalSampler(num_of_data=1000) | data = NormalSampler(num_of_data=1000) | ||||
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) | wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) | ||||
dataloader = iter(wrapper(dataloader=data)) | |||||
dataloader = iter(wrapper) | |||||
mark = 0 | mark = 0 | ||||
while True: | while True: | ||||
try: | try: | ||||
@@ -32,8 +32,7 @@ class Test_WrapDataLoader: | |||||
dataset = TorchNormalDataset(num_of_data=1000) | dataset = TorchNormalDataset(num_of_data=1000) | ||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | ||||
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | ||||
dataloader = wrapper(dataloader) | |||||
dataloader = iter(dataloader) | |||||
dataloader = iter(wrapper) | |||||
all_supposed_running_data_num = 0 | all_supposed_running_data_num = 0 | ||||
while True: | while True: | ||||
try: | try: | ||||
@@ -55,6 +54,5 @@ class Test_WrapDataLoader: | |||||
dataset = TorchNormalDataset(num_of_data=1000) | dataset = TorchNormalDataset(num_of_data=1000) | ||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | ||||
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | ||||
dataloader = wrapper(dataloader) | |||||
length.append(len(dataloader)) | |||||
length.append(len(wrapper)) | |||||
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) | assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) |
@@ -15,7 +15,7 @@ else: | |||||
class Model (Module): | |||||
class Model(Module): | |||||
def __init__ (self): | def __init__ (self): | ||||
super (Model, self).__init__() | super (Model, self).__init__() | ||||
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | ||||
@@ -45,6 +45,7 @@ class Model (Module): | |||||
return x | return x | ||||
@pytest.mark.jittor | @pytest.mark.jittor | ||||
@pytest.mark.skip("Skip jittor tests now.") | |||||
class TestSingleDevice: | class TestSingleDevice: | ||||
def test_on_gpu_without_fp16(self): | def test_on_gpu_without_fp16(self): | ||||
@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | ||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | ||||
device = [torch.device(i) for i in device] | device = [torch.device(i) for i in device] | ||||
@@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | |||||
def test_multi_drivers(): | |||||
""" | |||||
测试使用了多个 TorchDDPDriver 的情况。 | |||||
""" | |||||
generate_driver(10, 10) | |||||
generate_driver(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
# 设备设置不同,应该报错 | |||||
generate_driver(20, 3, device=[0,1,2]) | |||||
assert False | |||||
dist.barrier() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestDDPDriverFunction: | class TestDDPDriverFunction: | ||||
""" | """ | ||||
测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | ||||
""" | """ | ||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_multi_drivers(self): | |||||
def test_simple_functions(self): | |||||
""" | """ | ||||
测试使用了多个 TorchDDPDriver 的情况。 | |||||
简单测试多个函数 | |||||
""" | """ | ||||
driver2 = generate_driver(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
# 设备设置不同,应该报错 | |||||
driver3 = generate_driver(20, 3, device=[0,1,2]) | |||||
assert False | |||||
dist.barrier() | |||||
driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | |||||
def test_move_data_to_device(self): | |||||
""" | """ | ||||
这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中 | |||||
就不重复测试了 | |||||
测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在 | |||||
tests/core/utils/test_torch_utils.py中,就不重复测试了 | |||||
""" | """ | ||||
self.driver.move_data_to_device(torch.rand((32, 64))) | |||||
driver.move_data_to_device(torch.rand((32, 64))) | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_is_distributed(self): | |||||
""" | """ | ||||
测试 is_distributed 函数 | 测试 is_distributed 函数 | ||||
""" | """ | ||||
assert self.driver.is_distributed() == True | |||||
assert driver.is_distributed() == True | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_get_no_sync_context(self): | |||||
""" | """ | ||||
测试 get_no_sync_context 函数 | 测试 get_no_sync_context 函数 | ||||
""" | """ | ||||
res = self.driver.get_model_no_sync_context() | |||||
res = driver.get_model_no_sync_context() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_is_global_zero(self): | |||||
""" | """ | ||||
测试 is_global_zero 函数 | 测试 is_global_zero 函数 | ||||
""" | """ | ||||
self.driver.is_global_zero() | |||||
driver.is_global_zero() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_unwrap_model(self): | |||||
""" | """ | ||||
测试 unwrap_model 函数 | 测试 unwrap_model 函数 | ||||
""" | """ | ||||
self.driver.unwrap_model() | |||||
driver.unwrap_model() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_get_local_rank(self): | |||||
""" | """ | ||||
测试 get_local_rank 函数 | 测试 get_local_rank 函数 | ||||
""" | """ | ||||
self.driver.get_local_rank() | |||||
driver.get_local_rank() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_all_gather(self): | |||||
""" | """ | ||||
测试 all_gather 函数 | 测试 all_gather 函数 | ||||
详细的测试在 test_dist_utils.py 中完成 | 详细的测试在 test_dist_utils.py 中完成 | ||||
""" | """ | ||||
obj = { | obj = { | ||||
"rank": self.driver.global_rank | |||||
"rank": driver.global_rank | |||||
} | } | ||||
obj_list = self.driver.all_gather(obj, group=None) | |||||
obj_list = driver.all_gather(obj, group=None) | |||||
for i, res in enumerate(obj_list): | for i, res in enumerate(obj_list): | ||||
assert res["rank"] == i | assert res["rank"] == i | ||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("src_rank", ([0, 1])) | |||||
def test_broadcast_object(self, src_rank): | |||||
""" | """ | ||||
测试 broadcast_object 函数 | 测试 broadcast_object 函数 | ||||
详细的函数在 test_dist_utils.py 中完成 | 详细的函数在 test_dist_utils.py 中完成 | ||||
""" | """ | ||||
if self.driver.global_rank == src_rank: | |||||
if driver.global_rank == 0: | |||||
obj = { | obj = { | ||||
"rank": self.driver.global_rank | |||||
"rank": driver.global_rank | |||||
} | } | ||||
else: | else: | ||||
obj = None | obj = None | ||||
res = self.driver.broadcast_object(obj, src=src_rank) | |||||
assert res["rank"] == src_rank | |||||
res = driver.broadcast_object(obj, src=0) | |||||
assert res["rank"] == 0 | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||
@@ -187,7 +180,6 @@ class TestSetDistReproDataloader: | |||||
@classmethod | @classmethod | ||||
def setup_class(cls): | def setup_class(cls): | ||||
cls.device = [0, 1] | cls.device = [0, 1] | ||||
cls.driver = generate_driver(10, 10, device=cls.device) | |||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchNormalDataset(40) | self.dataset = TorchNormalDataset(40) | ||||
@@ -204,17 +196,20 @@ class TestSetDistReproDataloader: | |||||
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 | 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 | ||||
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler | 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | ||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
assert replaced_loader.batch_sampler is batch_sampler | assert replaced_loader.batch_sampler is batch_sampler | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -223,9 +218,10 @@ class TestSetDistReproDataloader: | |||||
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 | 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 | ||||
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler | 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | ||||
sampler = RandomSampler(self.dataset, shuffle=shuffle) | sampler = RandomSampler(self.dataset, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -234,9 +230,11 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler is sampler | assert replaced_loader.batch_sampler.sampler is sampler | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
""" | """ | ||||
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` | 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` | ||||
@@ -251,15 +249,17 @@ class TestSetDistReproDataloader: | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 | 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 | ||||
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 | 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | ||||
with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
# 应当抛出 RuntimeError | # 应当抛出 RuntimeError | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
# @pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | ||||
""" | """ | ||||
@@ -268,21 +268,24 @@ class TestSetDistReproDataloader: | |||||
此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler | 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler | ||||
和原 dataloader 相同 | 和原 dataloader 相同 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == 4 | assert replaced_loader.batch_sampler.batch_size == 4 | ||||
self.check_distributed_sampler(dataloader.batch_sampler) | self.check_distributed_sampler(dataloader.batch_sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -292,12 +295,13 @@ class TestSetDistReproDataloader: | |||||
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 | 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 | ||||
batch_sampler.sampler 和原 dataloader 相同 | batch_sampler.sampler 和原 dataloader 相同 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank | |||||
) | ) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -307,9 +311,11 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | assert replaced_loader.batch_sampler.batch_size == 4 | ||||
assert replaced_loader.batch_sampler.drop_last == False | assert replaced_loader.batch_sampler.drop_last == False | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -318,11 +324,14 @@ class TestSetDistReproDataloader: | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 | 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 | ||||
此时直接返回原来的 dataloader,不做任何处理。 | 此时直接返回原来的 dataloader,不做任何处理。 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert replaced_loader is dataloader | assert replaced_loader is dataloader | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
""" | """ | ||||
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | ||||
@@ -337,12 +346,13 @@ class TestSetDistReproDataloader: | |||||
的表现 | 的表现 | ||||
此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 | 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=self.dataset, | dataset=self.dataset, | ||||
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | ||||
) | ) | ||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
@@ -351,6 +361,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -361,8 +373,9 @@ class TestSetDistReproDataloader: | |||||
此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 | 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 | ||||
的属性 | 的属性 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
@@ -372,6 +385,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -381,8 +396,9 @@ class TestSetDistReproDataloader: | |||||
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 | 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 | ||||
的属性 | 的属性 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -392,6 +408,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
""" | """ | ||||
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | ||||
@@ -407,8 +425,9 @@ class TestSetDistReproDataloader: | |||||
此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 | 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 | ||||
的属性 | 的属性 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -418,6 +437,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -427,8 +448,9 @@ class TestSetDistReproDataloader: | |||||
的表现 | 的表现 | ||||
此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler | 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -439,6 +461,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -448,8 +472,9 @@ class TestSetDistReproDataloader: | |||||
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 | 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 | ||||
的属性 | 的属性 | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -459,6 +484,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
def check_distributed_sampler(self, sampler): | def check_distributed_sampler(self, sampler): | ||||
""" | """ | ||||
@@ -469,7 +496,7 @@ class TestSetDistReproDataloader: | |||||
if not isinstance(sampler, UnrepeatedSampler): | if not isinstance(sampler, UnrepeatedSampler): | ||||
assert sampler.pad == True | assert sampler.pad == True | ||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): | |||||
""" | """ | ||||
测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | ||||
""" | """ | ||||
@@ -501,8 +528,8 @@ class TestSetDistReproDataloader: | |||||
drop_last=False, | drop_last=False, | ||||
) | ) | ||||
new_loader.batch_sampler.set_distributed( | new_loader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
@@ -512,8 +539,8 @@ class TestSetDistReproDataloader: | |||||
# 重新构造 dataloader | # 重新构造 dataloader | ||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | ||||
new_loader.batch_sampler.sampler.set_distributed( | new_loader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank | |||||
) | ) | ||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
@@ -534,11 +561,6 @@ class TestSaveLoad: | |||||
测试多卡情况下 save 和 load 相关函数的表现 | 测试多卡情况下 save 和 load 相关函数的表现 | ||||
""" | """ | ||||
@classmethod | |||||
def setup_class(cls): | |||||
# 不在这里 setup 的话会报错 | |||||
cls.driver = generate_driver(10, 10) | |||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchArgMaxDataset(10, 20) | self.dataset = TorchArgMaxDataset(10, 20) | ||||
@@ -552,26 +574,26 @@ class TestSaveLoad: | |||||
path = "model" | path = "model" | ||||
dataloader = DataLoader(self.dataset, batch_size=2) | dataloader = DataLoader(self.dataset, batch_size=2) | ||||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
self.driver1.save_model(path, only_state_dict) | |||||
driver1.save_model(path, only_state_dict) | |||||
# 同步 | # 同步 | ||||
dist.barrier() | dist.barrier() | ||||
self.driver2.load_model(path, only_state_dict) | |||||
driver2.load_model(path, only_state_dict) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
batch = self.driver1.move_data_to_device(batch) | |||||
res1 = self.driver1.model( | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | # Driver.model -> DataParallel.module -> _FleetWrappingModel.model | ||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
res2 = self.driver2.model( | |||||
res2 = driver2.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
@@ -580,6 +602,9 @@ class TestSaveLoad: | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
@@ -593,7 +618,7 @@ class TestSaveLoad: | |||||
path = "model.ckp" | path = "model.ckp" | ||||
num_replicas = len(device) | num_replicas = len(device) | ||||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
generate_driver(10, 10, device=device, fp16=False) | generate_driver(10, 10, device=device, fp16=False) | ||||
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
self.dataset, | self.dataset, | ||||
@@ -603,8 +628,8 @@ class TestSaveLoad: | |||||
drop_last=False | drop_last=False | ||||
) | ) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
num_replicas=driver1.world_size, | |||||
rank=driver1.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
@@ -623,7 +648,7 @@ class TestSaveLoad: | |||||
# 保存状态 | # 保存状态 | ||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
@@ -634,11 +659,11 @@ class TestSaveLoad: | |||||
drop_last=False | drop_last=False | ||||
) | ) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
num_replicas=driver2.world_size, | |||||
rank=driver2.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -652,7 +677,7 @@ class TestSaveLoad: | |||||
# 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
if fp16: | if fp16: | ||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. 检查 model 的参数是否正确 | # 4. 检查 model 的参数是否正确 | ||||
# 5. 检查 batch_idx | # 5. 检查 batch_idx | ||||
@@ -664,16 +689,16 @@ class TestSaveLoad: | |||||
left_x_batches.update(batch["x"]) | left_x_batches.update(batch["x"]) | ||||
left_y_batches.update(batch["y"]) | left_y_batches.update(batch["y"]) | ||||
res1 = self.driver1.model( | |||||
res1 = driver1.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | # Driver.model -> DataParallel.module -> _FleetWrappingModel.model | ||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
res2 = self.driver2.model( | |||||
res2 = driver2.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
@@ -686,6 +711,9 @@ class TestSaveLoad: | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
@@ -700,13 +728,13 @@ class TestSaveLoad: | |||||
num_replicas = len(device) | num_replicas = len(device) | ||||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
num_replicas=driver1.world_size, | |||||
rank=driver1.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
@@ -726,18 +754,18 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
num_replicas=driver2.world_size, | |||||
rank=driver2.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -753,7 +781,7 @@ class TestSaveLoad: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | ||||
# 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
if fp16: | if fp16: | ||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. 检查 model 的参数是否正确 | # 4. 检查 model 的参数是否正确 | ||||
# 5. 检查 batch_idx | # 5. 检查 batch_idx | ||||
@@ -765,16 +793,16 @@ class TestSaveLoad: | |||||
left_x_batches.update(batch["x"]) | left_x_batches.update(batch["x"]) | ||||
left_y_batches.update(batch["y"]) | left_y_batches.update(batch["y"]) | ||||
res1 = self.driver1.model( | |||||
res1 = driver1.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | # Driver.model -> DataParallel.module -> _FleetWrappingModel.model | ||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
res2 = self.driver2.model( | |||||
res2 = driver2.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
@@ -786,4 +814,7 @@ class TestSaveLoad: | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | ||||
finally: | finally: | ||||
rank_zero_rm(path) | |||||
rank_zero_rm(path) | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() |
@@ -2,12 +2,14 @@ import pytest | |||||
from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | ||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | ||||
from fastNLP.envs import get_gpu_count | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
import torch | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import device as torchdevice | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
@@ -20,7 +22,7 @@ def test_incorrect_driver(): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "cuda:0", 0, torch.device("cuda:0")] | |||||
["cpu", "cuda:0", 0, torchdevice("cuda:0")] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
@@ -83,7 +85,6 @@ def test_get_ddp(driver, device): | |||||
("driver", "device"), | ("driver", "device"), | ||||
[("torch_ddp", "cpu")] | [("torch_ddp", "cpu")] | ||||
) | ) | ||||
@magic_argv_env_context | |||||
def test_get_ddp_cpu(driver, device): | def test_get_ddp_cpu(driver, device): | ||||
""" | """ | ||||
测试试图在 cpu 上初始化分布式训练的情况 | 测试试图在 cpu 上初始化分布式训练的情况 | ||||
@@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] | |||||
[-2, [0, 20, 3], [-2], 20] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
["torch", "torch_ddp"] | ["torch", "torch_ddp"] | ||||
) | ) | ||||
@magic_argv_env_context | |||||
def test_device_out_of_range(driver, device): | def test_device_out_of_range(driver, device): | ||||
""" | """ | ||||
测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
@@ -7,15 +7,20 @@ import copy | |||||
import socket | import socket | ||||
import pytest | import pytest | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
import torch.distributed | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
from sklearn.metrics import accuracy_score as sklearn_accuracy | from sklearn.metrics import accuracy_score as sklearn_accuracy | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.metrics.accuracy import Accuracy | from fastNLP.core.metrics.accuracy import Accuracy | ||||
from fastNLP.core.metrics.metric import Metric | from fastNLP.core.metrics.metric import Metric | ||||
from .utils import find_free_network_port, setup_ddp, _assert_allclose | from .utils import find_free_network_port, setup_ddp, _assert_allclose | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method | |||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
@@ -26,7 +31,7 @@ pool = None | |||||
def _test(local_rank: int, | def _test(local_rank: int, | ||||
world_size: int, | world_size: int, | ||||
device: torch.device, | |||||
device: "torch.device", | |||||
dataset: DataSet, | dataset: DataSet, | ||||
metric_class: Type[Metric], | metric_class: Type[Metric], | ||||
metric_kwargs: Dict[str, Any], | metric_kwargs: Dict[str, Any], | ||||
@@ -2,18 +2,23 @@ from functools import partial | |||||
import copy | import copy | ||||
import pytest | import pytest | ||||
import torch | |||||
import numpy as np | import numpy as np | ||||
from torch.multiprocessing import Pool, set_start_method | |||||
from fastNLP.core.metrics import ClassifyFPreRecMetric | from fastNLP.core.metrics import ClassifyFPreRecMetric | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from .utils import find_free_network_port, setup_ddp | from .utils import find_free_network_port, setup_ddp | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method | |||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
def _test(local_rank: int, world_size: int, device: torch.device, | |||||
def _test(local_rank: int, world_size: int, device: "torch.device", | |||||
dataset: DataSet, metric_class, metric_kwargs, metric_result): | dataset: DataSet, metric_class, metric_kwargs, metric_result): | ||||
metric = metric_class(**metric_kwargs) | metric = metric_class(**metric_kwargs) | ||||
# dataset 也类似(每个进程有自己的一个) | # dataset 也类似(每个进程有自己的一个) | ||||
@@ -5,16 +5,21 @@ import os, sys | |||||
import copy | import copy | ||||
from functools import partial | from functools import partial | ||||
import torch | |||||
import torch.distributed | |||||
import numpy as np | import numpy as np | ||||
import socket | import socket | ||||
from torch.multiprocessing import Pool, set_start_method | |||||
# from multiprocessing import Pool, set_start_method | # from multiprocessing import Pool, set_start_method | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from .utils import find_free_network_port, setup_ddp | from .utils import find_free_network_port, setup_ddp | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method | |||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
@@ -44,7 +49,7 @@ pool = None | |||||
def _test(local_rank: int, | def _test(local_rank: int, | ||||
world_size: int, | world_size: int, | ||||
device: torch.device, | |||||
device: "torch.device", | |||||
dataset: DataSet, | dataset: DataSet, | ||||
metric_class, | metric_class, | ||||
metric_kwargs, | metric_kwargs, | ||||
@@ -2,9 +2,11 @@ import os, sys | |||||
import socket | import socket | ||||
from typing import Union | from typing import Union | ||||
import torch | |||||
from torch import distributed | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import distributed | |||||
def setup_ddp(rank: int, world_size: int, master_port: int) -> None: | def setup_ddp(rank: int, world_size: int, master_port: int) -> None: | ||||
@@ -3,6 +3,7 @@ import pytest | |||||
import subprocess | import subprocess | ||||
from io import StringIO | from io import StringIO | ||||
import sys | import sys | ||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | |||||
from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
@@ -1,4 +1,5 @@ | |||||
import os | import os | ||||
import pytest | |||||
from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
@@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): | |||||
filepath = None | filepath = None | ||||
try: | try: | ||||
with Capturing() as output: | with Capturing() as output: | ||||
dump_fastnlp_backend() | |||||
dump_fastnlp_backend(backend="torch") | |||||
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | ||||
assert filepath in output[0] | assert filepath in output[0] | ||||
assert os.path.exists(filepath) | assert os.path.exists(filepath) | ||||
@@ -1,7 +1,9 @@ | |||||
import torch | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from fastNLP.core.callbacks.callback import Callback | from fastNLP.core.callbacks.callback import Callback | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
class RecordAccumulationStepsCallback_Torch(Callback): | class RecordAccumulationStepsCallback_Torch(Callback): | ||||
@@ -1,7 +1,11 @@ | |||||
import torch | import torch | ||||
from functools import reduce | from functools import reduce | ||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
class TorchNormalDataset(Dataset): | class TorchNormalDataset(Dataset): | ||||
@@ -1,9 +1,14 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.nn import Module | |||||
import torch.nn as nn | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
# 1. 最为基础的分类模型 | # 1. 最为基础的分类模型 | ||||
class TorchNormalModel_Classification_1(nn.Module): | |||||
class TorchNormalModel_Classification_1(Module): | |||||
""" | """ | ||||
单独实现 train_step 和 evaluate_step; | 单独实现 train_step 和 evaluate_step; | ||||
""" | """ | ||||
@@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||||
return {"preds": x, "target": y} | return {"preds": x, "target": y} | ||||
class TorchNormalModel_Classification_2(nn.Module): | |||||
class TorchNormalModel_Classification_2(Module): | |||||
""" | """ | ||||
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | ||||
""" | """ | ||||
@@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module): | |||||
return {"loss": loss, "preds": x, "target": y} | return {"loss": loss, "preds": x, "target": y} | ||||
class TorchNormalModel_Classification_3(nn.Module): | |||||
class TorchNormalModel_Classification_3(Module): | |||||
""" | """ | ||||
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | ||||
关闭 auto_param_call,forward 只有一个 batch 参数; | 关闭 auto_param_call,forward 只有一个 batch 参数; | ||||
@@ -0,0 +1,6 @@ | |||||
[pytest] | |||||
markers = | |||||
torch | |||||
paddle | |||||
jittor | |||||
torchpaddle |