|
@@ -1,4 +1,5 @@ |
|
|
from functools import reduce |
|
|
from functools import reduce |
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
|
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: 该类修改过,记得将 test 也修改; |
|
|
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: 该类修改过,记得将 test 也修改; |
|
|
from tests.helpers.datasets.normal_data import NormalSampler |
|
|
from tests.helpers.datasets.normal_data import NormalSampler |
|
@@ -21,6 +22,7 @@ class Test_WrapDataLoader: |
|
|
mark += 1 |
|
|
mark += 1 |
|
|
assert mark == sanity_batches |
|
|
assert mark == sanity_batches |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
def test_torch_dataloader(self): |
|
|
def test_torch_dataloader(self): |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.data import DataLoader |
|
@@ -42,6 +44,7 @@ class Test_WrapDataLoader: |
|
|
all_supposed_running_data_num += _data.shape[0] |
|
|
all_supposed_running_data_num += _data.shape[0] |
|
|
assert all_supposed_running_data_num == bs * sanity_batches |
|
|
assert all_supposed_running_data_num == bs * sanity_batches |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
def test_len(self): |
|
|
def test_len(self): |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.data import DataLoader |
|
|