|
|
@@ -121,7 +121,9 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): |
|
|
|
:param kwargs: fastNLP 保留使用 |
|
|
|
""" |
|
|
|
def __init__(self, dataset, length:Union[str, List], **kwargs): |
|
|
|
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) |
|
|
|
kwargs['shuffle'] = False |
|
|
|
kwargs['seed'] = 0 |
|
|
|
super().__init__(dataset=dataset, **kwargs) |
|
|
|
if isinstance(dataset, DataSet) and isinstance(length, str): |
|
|
|
length = dataset.get_field(length).content |
|
|
|
if not isinstance(length[0], int): |
|
|
@@ -141,17 +143,32 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): |
|
|
|
|
|
|
|
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): |
|
|
|
""" |
|
|
|
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 |
|
|
|
按照顺序读取 dataset。 |
|
|
|
|
|
|
|
:param dataset: 实现了 __len__ 方法的数据容器。 |
|
|
|
:param chunk_dist: 如果为 True ,当多卡时,将不间隔索取数据;为 False ,间隔取数据。例如,假设 dataset 有 10 个 sample ,使用 |
|
|
|
2 卡,如果为 True ,卡 0 拿 [0, 1, 2, 3, 4], 卡 1 拿 [5, 6, 7, 8, 9] ; 如果为 False ,则卡 0 拿 [0, 2, 4, 8, 8], 卡 |
|
|
|
1 拿 [1, 3, 5, 7, 9] 。 |
|
|
|
:param kwargs: |
|
|
|
""" |
|
|
|
def __init__(self, dataset, **kwargs): |
|
|
|
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) |
|
|
|
def __init__(self, dataset, chunk_dist=False, **kwargs): |
|
|
|
kwargs['shuffle'] = False |
|
|
|
kwargs['seed'] = 0 |
|
|
|
super(UnrepeatedSequentialSampler, self).__init__(dataset, **kwargs) |
|
|
|
self.chunk_dist = chunk_dist |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
indices = self.generate_indices() |
|
|
|
indices = indices[self.rank:len(indices):self.num_replicas] |
|
|
|
if self.num_replicas>1: |
|
|
|
if self.chunk_dist: |
|
|
|
chunk_size = len(indices)//self.num_replicas |
|
|
|
start = chunk_size * self.rank |
|
|
|
end = chunk_size * (self.rank + 1) |
|
|
|
if self.rank == self.num_replicas - 1: |
|
|
|
end = len(indices) |
|
|
|
indices = indices[start:end] |
|
|
|
else: |
|
|
|
indices = indices[self.rank:len(indices):self.num_replicas] |
|
|
|
for index in indices: |
|
|
|
yield index |
|
|
|
|
|
|
|