Browse Source

UnrepeatedSampler 支持 chunk_dist

tags/v1.0.0alpha
yh 2 years ago
parent
commit
1d665c9480
4 changed files with 28 additions and 10 deletions
  1. +1
    -1
      fastNLP/core/callbacks/checkpoint_callback.py
  2. +2
    -2
      fastNLP/core/dataset/dataset.py
  3. +22
    -5
      fastNLP/core/samplers/unrepeated_sampler.py
  4. +3
    -2
      tests/core/samplers/test_unrepeated_sampler.py

+ 1
- 1
fastNLP/core/callbacks/checkpoint_callback.py View File

@@ -58,7 +58,7 @@ class CheckpointCallback(Callback):
""" """
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None,
every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0,
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException],
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = (EarlyStopException,),
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True,
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model',
save_evaluate_results=True, **kwargs): save_evaluate_results=True, **kwargs):


+ 2
- 2
fastNLP/core/dataset/dataset.py View File

@@ -402,10 +402,10 @@ class DataSet:


def __getattr__(self, item): def __getattr__(self, item):
# Not tested. Don't use !! # Not tested. Don't use !!
if item == "field_arrays":
raise AttributeError
if isinstance(item, str) and item in self.field_arrays: if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item] return self.field_arrays[item]
else:
raise AttributeError


def __setstate__(self, state): def __setstate__(self, state):
self.__dict__ = state self.__dict__ = state


+ 22
- 5
fastNLP/core/samplers/unrepeated_sampler.py View File

@@ -121,7 +121,9 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
:param kwargs: fastNLP 保留使用 :param kwargs: fastNLP 保留使用
""" """
def __init__(self, dataset, length:Union[str, List], **kwargs): def __init__(self, dataset, length:Union[str, List], **kwargs):
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
kwargs['shuffle'] = False
kwargs['seed'] = 0
super().__init__(dataset=dataset, **kwargs)
if isinstance(dataset, DataSet) and isinstance(length, str): if isinstance(dataset, DataSet) and isinstance(length, str):
length = dataset.get_field(length).content length = dataset.get_field(length).content
if not isinstance(length[0], int): if not isinstance(length[0], int):
@@ -141,17 +143,32 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):


class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
""" """
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。
按照顺序读取 dataset。


:param dataset: 实现了 __len__ 方法的数据容器。 :param dataset: 实现了 __len__ 方法的数据容器。
:param chunk_dist: 如果为 True ,当多卡时,将不间隔索取数据;为 False ,间隔取数据。例如,假设 dataset 有 10 个 sample ,使用
2 卡,如果为 True ,卡 0 拿 [0, 1, 2, 3, 4], 卡 1 拿 [5, 6, 7, 8, 9] ; 如果为 False ,则卡 0 拿 [0, 2, 4, 6, 8], 卡
1 拿 [1, 3, 5, 7, 9] 。
:param kwargs: :param kwargs:
""" """
def __init__(self, dataset, chunk_dist=False, **kwargs):
    """Sequential (non-shuffled) unrepeated sampler.

    :param dataset: any container implementing ``__len__``.
    :param chunk_dist: if True, in the multi-card case each rank takes a
        contiguous chunk of the indices; if False (default), ranks take
        interleaved indices. See the class docstring for an example.
    :param kwargs: forwarded to the parent sampler.
    """
    # Overwrite in kwargs (rather than passing shuffle/seed positionally)
    # so a caller-supplied shuffle/seed cannot trigger a "got multiple
    # values" TypeError; sequential order is forced regardless.
    kwargs['shuffle'] = False
    kwargs['seed'] = 0
    super(UnrepeatedSequentialSampler, self).__init__(dataset, **kwargs)
    self.chunk_dist = chunk_dist


def __iter__(self):
    """Yield the dataset indices assigned to this rank, in order.

    With ``chunk_dist=True`` each rank gets a contiguous chunk of
    ``len(indices) // num_replicas`` indices, and the last rank also
    absorbs the remainder; with ``chunk_dist=False`` ranks take
    interleaved indices (rank, rank + num_replicas, ...).
    """
    indices = self.generate_indices()
    if self.num_replicas > 1:
        if self.chunk_dist:
            chunk_size = len(indices) // self.num_replicas
            start = chunk_size * self.rank
            end = chunk_size * (self.rank + 1)
            if self.rank == self.num_replicas - 1:
                # Last rank takes the leftover when the length is not
                # evenly divisible by num_replicas.
                end = len(indices)
            indices = indices[start:end]
        else:
            indices = indices[self.rank:len(indices):self.num_replicas]
    for index in indices:
        yield index




+ 3
- 2
tests/core/samplers/test_unrepeated_sampler.py View File

@@ -87,13 +87,14 @@ class TestUnrepeatedSequentialSampler:


@pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replicas, num_of_data):
@pytest.mark.parametrize('chunk_dist', [True, False])
def test_multi(self, num_replicas, num_of_data, chunk_dist):
if num_replicas > num_of_data: if num_replicas > num_of_data:
pytest.skip("num_replicas > num_of_data") pytest.skip("num_replicas > num_of_data")
data = DatasetWithVaryLength(num_of_data=num_of_data) data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = [] samplers = []
for i in range(num_replicas): for i in range(num_replicas):
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
sampler = UnrepeatedSequentialSampler(dataset=data, chunk_dist=chunk_dist)
sampler.set_distributed(num_replicas, rank=i) sampler.set_distributed(num_replicas, rank=i)
samplers.append(sampler) samplers.append(sampler)




Loading…
Cancel
Save