@@ -58,7 +58,7 @@ class CheckpointCallback(Callback): | |||||
""" | """ | ||||
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | ||||
every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, | every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, | ||||
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException], | |||||
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = (EarlyStopException), | |||||
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | ||||
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | ||||
save_evaluate_results=True, **kwargs): | save_evaluate_results=True, **kwargs): | ||||
@@ -402,10 +402,10 @@ class DataSet: | |||||
def __getattr__(self, item):
    """Resolve an otherwise-missing attribute through ``self.field_arrays``.

    Python calls this hook only after normal attribute lookup has failed.

    :param item: name of the attribute being looked up.
    :return: ``self.field_arrays[item]`` when ``item`` is a string key of
        ``self.field_arrays``.
    :raises AttributeError: if ``item`` is ``"field_arrays"`` itself or is not
        a key of ``self.field_arrays``.
    """
    # Not tested. Don't use !!
    # Guard: if ``field_arrays`` is absent from the instance ``__dict__``
    # (e.g. an object created via ``__new__`` whose state has not been
    # restored yet), reading ``self.field_arrays`` below would re-enter this
    # method and recurse until RecursionError instead of raising
    # AttributeError.
    if item == "field_arrays":
        raise AttributeError(item)
    if isinstance(item, str) and item in self.field_arrays:
        return self.field_arrays[item]
    # Always raise explicitly (never fall through and return None) so that
    # ``hasattr`` / ``getattr(obj, name, default)`` behave as expected.
    raise AttributeError(item)
def __setstate__(self, state): | def __setstate__(self, state): | ||||
self.__dict__ = state | self.__dict__ = state | ||||
@@ -121,7 +121,9 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | def __init__(self, dataset, length:Union[str, List], **kwargs): | ||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||||
kwargs['shuffle'] = False | |||||
kwargs['seed'] = 0 | |||||
super().__init__(dataset=dataset, **kwargs) | |||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
@@ -141,17 +143,32 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
    """
    Read the dataset in order (shuffling is always disabled).

    :param dataset: a data container that implements ``__len__``.
    :param chunk_dist: selects how indices are split when more than one replica
        is used. If ``True`` each rank takes one contiguous chunk; if ``False``
        ranks take interleaved (strided) indices. For example, with 10 samples
        on 2 ranks: ``True`` gives rank 0 ``[0, 1, 2, 3, 4]`` and rank 1
        ``[5, 6, 7, 8, 9]``; ``False`` gives rank 0 ``[0, 2, 4, 6, 8]`` and
        rank 1 ``[1, 3, 5, 7, 9]``. In chunk mode the chunk size is
        ``len(indices) // num_replicas`` and the last rank additionally takes
        any remainder.
    :param kwargs: forwarded to the parent constructor; ``shuffle`` and
        ``seed`` are always overwritten with ``False`` and ``0``.
    """

    def __init__(self, dataset, chunk_dist=False, **kwargs):
        # Overwrite rather than pass as explicit keywords, so a caller that
        # supplied shuffle/seed does not trigger a duplicate-keyword TypeError.
        kwargs.update(shuffle=False, seed=0)
        super().__init__(dataset, **kwargs)
        self.chunk_dist = chunk_dist

    def __iter__(self):
        order = self.generate_indices()
        world = self.num_replicas
        if world > 1:
            rank = self.rank
            if self.chunk_dist:
                # Contiguous split; the final rank absorbs the remainder.
                per_rank = len(order) // world
                lo = per_rank * rank
                hi = len(order) if rank == world - 1 else per_rank * (rank + 1)
                order = order[lo:hi]
            else:
                # Strided split: rank r takes r, r + world, r + 2 * world, ...
                order = order[rank::world]
        yield from order
@@ -87,13 +87,14 @@ class TestUnrepeatedSequentialSampler: | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, num_replicas, num_of_data): | |||||
@pytest.mark.parametrize('chunk_dist', [True, False]) | |||||
def test_multi(self, num_replicas, num_of_data, chunk_dist): | |||||
if num_replicas > num_of_data: | if num_replicas > num_of_data: | ||||
pytest.skip("num_replicas > num_of_data") | pytest.skip("num_replicas > num_of_data") | ||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replicas): | for i in range(num_replicas): | ||||
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) | |||||
sampler = UnrepeatedSequentialSampler(dataset=data, chunk_dist=chunk_dist) | |||||
sampler.set_distributed(num_replicas, rank=i) | sampler.set_distributed(num_replicas, rank=i) | ||||
samplers.append(sampler) | samplers.append(sampler) | ||||