From badd07f8f13784058bbd9d1da137bd2396b4168a Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 29 May 2022 07:04:24 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20ReproducibleSampler=20?= =?UTF-8?q?ReproducibleBatchSampler=20=E7=9A=84=20num=5Fsamples=20?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../samplers/reproducible_batch_sampler.py | 20 +++++++++++++++++-- fastNLP/core/samplers/reproducible_sampler.py | 5 ++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 874ad895..edb8a67f 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ @@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index a1dc318e..fe38a808 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler): :return: """ - return getattr(self.dataset, 'total_len', len(self.dataset)) + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len class SequentialSampler(RandomSampler): """ From 9b0a30c8fb9b7f79e319a9e3d6bcf11227442cd4 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 29 May 2022 13:51:10 +0000 Subject: [PATCH 2/2] small bug --- fastNLP/core/drivers/jittor_driver/jittor_driver.py | 3 +-- tests/core/drivers/paddle_driver/test_dist_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 5b38747d..e486df8e 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -1,7 +1,6 @@ import os from pathlib import Path from typing import Union, Optional, Dict -from contextlib import nullcontext from dataclasses import dataclass from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.log import logger -from fastNLP.core.utils import apply_to_collection +from fastNLP.core.utils import apply_to_collection, nullcontext from fastNLP.envs import ( FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, diff --git a/tests/core/drivers/paddle_driver/test_dist_utils.py b/tests/core/drivers/paddle_driver/test_dist_utils.py index e3a3eb5d..30ce9d29 100644 --- a/tests/core/drivers/paddle_driver/test_dist_utils.py +++ b/tests/core/drivers/paddle_driver/test_dist_utils.py @@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast: @classmethod def setup_class(cls): - devices = [0,1,2] + devices = [0,1] output_from_new_proc = "all" launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) @@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast: dist.barrier() @magic_argv_env_context - @pytest.mark.parametrize("src_rank", ([0, 1, 2])) + @pytest.mark.parametrize("src_rank", ([0, 1])) def test_fastnlp_paddle_broadcast_object(self, src_rank): if self.local_rank == src_rank: obj = {