diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 5b38747d..e486df8e 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -1,7 +1,6 @@ import os from pathlib import Path from typing import Union, Optional, Dict -from contextlib import nullcontext from dataclasses import dataclass from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.log import logger -from fastNLP.core.utils import apply_to_collection +from fastNLP.core.utils import apply_to_collection, nullcontext from fastNLP.envs import ( FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 874ad895..edb8a67f 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ @@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index a1dc318e..fe38a808 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler): :return: """ - return getattr(self.dataset, 'total_len', len(self.dataset)) + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len class SequentialSampler(RandomSampler): """ diff --git a/tests/core/drivers/paddle_driver/test_dist_utils.py b/tests/core/drivers/paddle_driver/test_dist_utils.py index e3a3eb5d..30ce9d29 100644 --- a/tests/core/drivers/paddle_driver/test_dist_utils.py +++ b/tests/core/drivers/paddle_driver/test_dist_utils.py @@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast: @classmethod def setup_class(cls): - devices = [0,1,2] + devices = [0,1] output_from_new_proc = "all" launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) @@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast: dist.barrier() @magic_argv_env_context - @pytest.mark.parametrize("src_rank", ([0, 1, 2])) + @pytest.mark.parametrize("src_rank", ([0, 1])) def test_fastnlp_paddle_broadcast_object(self, src_rank): if self.local_rank == src_rank: obj = {