diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 8a02ffbc..0cc3021b 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -58,7 +58,7 @@ class CheckpointCallback(Callback): """ def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, - on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException], + on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = (EarlyStopException,), monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', save_evaluate_results=True, **kwargs): diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 9da65112..63a1e079 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -402,10 +402,10 @@ class DataSet: def __getattr__(self, item): # Not tested. Don't use !! 
- if item == "field_arrays": - raise AttributeError if isinstance(item, str) and item in self.field_arrays: return self.field_arrays[item] + else: + raise AttributeError def __setstate__(self, state): self.__dict__ = state diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index 69eb532d..e94215a6 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -121,7 +121,9 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): :param kwargs: fastNLP 保留使用 """ def __init__(self, dataset, length:Union[str, List], **kwargs): - super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) + kwargs['shuffle'] = False + kwargs['seed'] = 0 + super().__init__(dataset=dataset, **kwargs) if isinstance(dataset, DataSet) and isinstance(length, str): length = dataset.get_field(length).content if not isinstance(length[0], int): @@ -141,17 +143,32 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): """ - 按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + 按照顺序读取 dataset。 :param dataset: 实现了 __len__ 方法的数据容器。 + :param chunk_dist: 如果为 True ,当多卡时,将不间隔索取数据;为 False ,间隔取数据。例如,假设 dataset 有 10 个 sample ,使用 + 2 卡,如果为 True ,卡 0 拿 [0, 1, 2, 3, 4], 卡 1 拿 [5, 6, 7, 8, 9] ; 如果为 False ,则卡 0 拿 [0, 2, 4, 6, 8], 卡 + 1 拿 [1, 3, 5, 7, 9] 。 :param kwargs: """ - def __init__(self, dataset, **kwargs): - super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) + def __init__(self, dataset, chunk_dist=False, **kwargs): + kwargs['shuffle'] = False + kwargs['seed'] = 0 + super(UnrepeatedSequentialSampler, self).__init__(dataset, **kwargs) + self.chunk_dist = chunk_dist def __iter__(self): indices = self.generate_indices() - indices = indices[self.rank:len(indices):self.num_replicas] + if self.num_replicas>1: + if self.chunk_dist: + chunk_size = len(indices)//self.num_replicas + start 
= chunk_size * self.rank + end = chunk_size * (self.rank + 1) + if self.rank == self.num_replicas - 1: + end = len(indices) + indices = indices[start:end] + else: + indices = indices[self.rank:len(indices):self.num_replicas] for index in indices: yield index diff --git a/tests/core/samplers/test_unrepeated_sampler.py b/tests/core/samplers/test_unrepeated_sampler.py index 0d16ec89..d3a74269 100644 --- a/tests/core/samplers/test_unrepeated_sampler.py +++ b/tests/core/samplers/test_unrepeated_sampler.py @@ -87,13 +87,14 @@ class TestUnrepeatedSequentialSampler: @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) - def test_multi(self, num_replicas, num_of_data): + @pytest.mark.parametrize('chunk_dist', [True, False]) + def test_multi(self, num_replicas, num_of_data, chunk_dist): if num_replicas > num_of_data: pytest.skip("num_replicas > num_of_data") data = DatasetWithVaryLength(num_of_data=num_of_data) samplers = [] for i in range(num_replicas): - sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) + sampler = UnrepeatedSequentialSampler(dataset=data, chunk_dist=chunk_dist) sampler.set_distributed(num_replicas, rank=i) samplers.append(sampler)