| @@ -202,26 +202,12 @@ class TorchDriver(Driver): | |||
| num_consumed_batches = states.pop('num_consumed_batches') | |||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
| sampler_states = sampler.state_dict() | |||
| # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。因为 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " | |||
| "`num_consumed_samples`, it may cause missing some samples when reload.") | |||
| states['sampler_states'] = sampler_states | |||
| else: | |||
| @@ -283,7 +283,7 @@ def optimizer_state_to_device(state, device): | |||
| def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||
| if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler, | |||
| if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler, | |||
| TorchSequentialSampler}): | |||
| mode = 'training' if controller == 'Trainer' else 'evaluation' | |||
| substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||
| @@ -13,7 +13,6 @@ from itertools import chain | |||
| import numpy as np | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.envs.utils import get_global_seed | |||
| from fastNLP.core.log import logger | |||
| from .utils import create_array | |||
| from abc import abstractmethod | |||
| @@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | |||
| drop_last: bool = False, seed: int = None, **kwargs): | |||
| drop_last: bool = False, seed: int = 0, **kwargs): | |||
| super().__init__() | |||
| self.dataset = dataset | |||
| @@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| self.batch_size = batch_size | |||
| self.shuffle = shuffle | |||
| self.drop_last = drop_last | |||
| self.seed = get_global_seed() if seed is None else seed | |||
| self.seed = int(seed) | |||
| self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||
| @@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | |||
| shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs): | |||
| shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||
| super().__init__() | |||
| if isinstance(dataset, DataSet) and isinstance(length, str): | |||
| length = dataset.get_field(length).content | |||
| @@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| self.num_batch_per_bucket = num_batch_per_bucket | |||
| self.shuffle = shuffle | |||
| self.drop_last = drop_last | |||
| self.seed = get_global_seed() if seed is None else seed | |||
| self.seed = int(seed) | |||
| self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||
| @@ -12,7 +12,6 @@ import numpy as np | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.envs.utils import get_global_seed | |||
| class ReproducibleSampler: | |||
| @@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler): | |||
| :param seed: 随机数种子。 | |||
| :param kwargs: 用户不需要使用,fastNLP 内部使用 | |||
| """ | |||
| def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs): | |||
| def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||
| super(RandomSampler, self).__init__() | |||
| self.dataset = dataset | |||
| self.shuffle = shuffle | |||
| self.seed = get_global_seed() if seed is None else seed | |||
| self.seed = int(seed) | |||
| self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||
| @@ -7,7 +7,6 @@ __all__ = [ | |||
| from typing import List, Union | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.envs.utils import get_global_seed | |||
| import numpy as np | |||
| @@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
| :param seed: 设置的随机数种子 | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs): | |||
| def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||
| self.dataset = dataset | |||
| self.shuffle = shuffle | |||
| self.seed = get_global_seed() if seed is None else seed | |||
| self.seed = int(seed) | |||
| # 多卡的相关的参数 | |||
| self.num_replicas = kwargs.get('num_replicas', 1) | |||
| @@ -1,3 +1,5 @@ | |||
| import os | |||
| import pytest | |||
| from pathlib import Path | |||
| @@ -185,7 +187,7 @@ class TestSetDistReproDataloader: | |||
| cls.device = [0, 1] | |||
| def setup_method(self): | |||
| self.dataset = TorchNormalDataset(40) | |||
| self.dataset = TorchNormalDataset(100) | |||
| """ | |||
| 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||
| @@ -571,7 +573,7 @@ class TestSaveLoad: | |||
| """ | |||
| def setup_method(self): | |||
| self.dataset = TorchNormalXYDataset(20) | |||
| self.dataset = TorchNormalXYDataset(100) | |||
| @magic_argv_env_context | |||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
| @@ -641,7 +643,7 @@ class TestSaveLoad: | |||
| rank=driver1.global_rank, | |||
| pad=True | |||
| ) | |||
| num_consumed_batches = 2 | |||
| num_consumed_batches = 4 | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| @@ -686,7 +688,8 @@ class TestSaveLoad: | |||
| assert not (replaced_loader is dataloader) | |||
| assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||
| assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||
| assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||
| if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||
| assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||
| assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | |||
| # 3. 检查 fp16 是否被加载 | |||
| @@ -753,7 +756,7 @@ class TestSaveLoad: | |||
| rank=driver1.global_rank, | |||
| pad=True | |||
| ) | |||
| num_consumed_batches = 2 | |||
| num_consumed_batches = 4 | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| @@ -792,11 +795,13 @@ class TestSaveLoad: | |||
| # 2. 检查 sampler 是否被正确地加载和替换 | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||
| assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||
| if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||
| assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||
| assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||
| assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||
| assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||
| assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | |||
| assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||
| assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||
| # 3. 检查 fp16 是否被加载 | |||
| if fp16: | |||
| assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||