@@ -1,7 +1,6 @@ | |||||
import os | import os | ||||
from pathlib import Path | from pathlib import Path | ||||
from typing import Union, Optional, Dict | from typing import Union, Optional, Dict | ||||
from contextlib import nullcontext | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
@@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver | |||||
from fastNLP.core.dataloaders import JittorDataLoader | from fastNLP.core.dataloaders import JittorDataLoader | ||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils import apply_to_collection | |||||
from fastNLP.core.utils import apply_to_collection, nullcontext | |||||
from fastNLP.envs import ( | from fastNLP.envs import ( | ||||
FASTNLP_MODEL_FILENAME, | FASTNLP_MODEL_FILENAME, | ||||
FASTNLP_CHECKPOINT_FILENAME, | FASTNLP_CHECKPOINT_FILENAME, | ||||
@@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def num_samples(self): | def num_samples(self): | ||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
""" | |||||
返回样本的总数 | |||||
:return: | |||||
""" | |||||
total_len = getattr(self.dataset, 'total_len', None) | |||||
if not isinstance(total_len, int): | |||||
total_len = len(self.dataset) | |||||
return total_len | |||||
def __len__(self)->int: | def __len__(self)->int: | ||||
""" | """ | ||||
@@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def num_samples(self): | def num_samples(self): | ||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
""" | |||||
返回样本的总数 | |||||
:return: | |||||
""" | |||||
total_len = getattr(self.dataset, 'total_len', None) | |||||
if not isinstance(total_len, int): | |||||
total_len = len(self.dataset) | |||||
return total_len | |||||
def __len__(self)->int: | def __len__(self)->int: | ||||
""" | """ | ||||
@@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
total_len = getattr(self.dataset, 'total_len', None) | |||||
if not isinstance(total_len, int): | |||||
total_len = len(self.dataset) | |||||
return total_len | |||||
class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
""" | """ | ||||
@@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast: | |||||
@classmethod | @classmethod | ||||
def setup_class(cls): | def setup_class(cls): | ||||
devices = [0,1,2] | |||||
devices = [0,1] | |||||
output_from_new_proc = "all" | output_from_new_proc = "all" | ||||
launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) | launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) | ||||
@@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast: | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("src_rank", ([0, 1, 2])) | |||||
@pytest.mark.parametrize("src_rank", ([0, 1])) | |||||
def test_fastnlp_paddle_broadcast_object(self, src_rank): | def test_fastnlp_paddle_broadcast_object(self, src_rank): | ||||
if self.local_rank == src_rank: | if self.local_rank == src_rank: | ||||
obj = { | obj = { | ||||