@@ -188,17 +188,27 @@ class TorchDriver(Driver): | |||||
num_consumed_batches = states.pop('num_consumed_batches') | num_consumed_batches = states.pop('num_consumed_batches') | ||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。 | |||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。因为 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | ||||
if num_consumed_samples_array is not None: | if num_consumed_samples_array is not None: | ||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | ||||
try: | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | ||||
except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | num_consumed_batches = sampler_states['num_consumed_samples'] | ||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | ||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | ||||
else: | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
@@ -63,15 +63,12 @@ class ClassifyFPreRecMetric(Metric): | |||||
evaluate_result = {} | evaluate_result = {} | ||||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | ||||
if self.aggregate_when_get_metric: | |||||
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||||
tps, fps, fns = zip(*ls) | |||||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||||
for _c in cs: | |||||
c.update(_c) | |||||
else: | |||||
_tp, _fp, _fn = self._tp, self._fp, self._tp | |||||
ls = self.all_gather_object([self._tp, self._fp, self._fn]) | |||||
tps, fps, fns = zip(*ls) | |||||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||||
for _c in cs: | |||||
c.update(_c) | |||||
if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
tags = set(_fn.keys()) | tags = set(_fn.keys()) | ||||
@@ -4,7 +4,7 @@ __all__ = [ | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from typing import Union | |||||
from typing import Union, List | |||||
import functools | import functools | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
import numpy as np | import numpy as np | ||||
@@ -180,3 +180,15 @@ class Metric: | |||||
""" | """ | ||||
for element in self.elements.values(): | for element in self.elements.values(): | ||||
element.to(device) | element.to(device) | ||||
def all_gather_object(self, obj, group=None)->List: | |||||
""" | |||||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||||
:param obj: 需要汇总的对象,必须是个 pickable 的对象。 | |||||
:param group: | |||||
:return: -> List[obj0, obj1, ...] 其中 obj0 是rank 0 上的 obj;obj1 是 rank 1 上的 obj... | |||||
""" | |||||
if self.aggregate_when_get_metric: | |||||
return self.backend.all_gather_object(obj, group=group) | |||||
return [obj] |
@@ -264,15 +264,12 @@ class SpanFPreRecMetric(Metric): | |||||
evaluate_result = {} | evaluate_result = {} | ||||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | ||||
if self.aggregate_when_get_metric: | |||||
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||||
tps, fps, fns = zip(*ls) | |||||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||||
for _c in cs: | |||||
c.update(_c) | |||||
else: | |||||
_tp, _fp, _fn = self._tp, self._fp, self._tp | |||||
ls = self.all_gather_object([self._tp, self._fp, self._fn]) | |||||
tps, fps, fns = zip(*ls) | |||||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||||
for _c in cs: | |||||
c.update(_c) | |||||
if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
tags = set(_fn.keys()) | tags = set(_fn.keys()) | ||||
@@ -13,9 +13,8 @@ import numpy as np | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .utils import create_array, NumConsumedSamplesArray | |||||
from .utils import create_array | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
class ReproducibleBatchSampler: | class ReproducibleBatchSampler: | ||||
@@ -37,9 +36,6 @@ class ReproducibleBatchSampler: | |||||
@abstractmethod | @abstractmethod | ||||
def state_dict(self): | def state_dict(self): | ||||
""" | """ | ||||
由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||||
了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -57,6 +53,14 @@ class ReproducibleBatchSampler: | |||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | ||||
@property | |||||
def num_replicas(self): | |||||
return self._num_replicas | |||||
@num_replicas.setter | |||||
def num_replicas(self, value): | |||||
self._num_replicas = value | |||||
class RandomBatchSampler(ReproducibleBatchSampler): | class RandomBatchSampler(ReproducibleBatchSampler): | ||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | ||||
@@ -105,24 +109,21 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
else: | else: | ||||
index_list = self.index_list | index_list = self.index_list | ||||
# 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||||
# 暂时弃用。记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||||
# batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | # batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | ||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
# self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
# num_consumed_samples=self.num_consumed_samples) | |||||
for idx in index_list: | for idx in index_list: | ||||
batch.append(idx) | batch.append(idx) | ||||
if len(batch) == self.batch_size: | if len(batch) == self.batch_size: | ||||
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield batch | yield batch | ||||
batch = [] | batch = [] | ||||
if len(batch) > 0 and not self.drop_last: | if len(batch) > 0 and not self.drop_last: | ||||
self.num_consumed_samples += len(batch) | self.num_consumed_samples += len(batch) | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield batch | yield batch | ||||
# 需要重置防止边界条件问题 | # 需要重置防止边界条件问题 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
def __len__(self) -> int: | def __len__(self) -> int: | ||||
if self.drop_last: | if self.drop_last: | ||||
@@ -136,7 +137,6 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
"num_consumed_samples": self.num_consumed_samples, | "num_consumed_samples": self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__ | 'sampler_type': self.__class__.__name__ | ||||
} | } | ||||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
@@ -327,15 +327,14 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | ||||
batches = batches[:-1] | batches = batches[:-1] | ||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
# 暂时弃用 | |||||
# self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
# num_consumed_samples=self.num_consumed_samples) | |||||
for batch in batches: | for batch in batches: | ||||
self.num_consumed_samples += self.num_replicas * len(batch) | self.num_consumed_samples += self.num_replicas * len(batch) | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield list(map(int, batch)) | yield list(map(int, batch)) | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
self.old_batch_size = self.batch_size | self.old_batch_size = self.batch_size | ||||
self.old_num_batch_per_bucket = self.num_batch_per_bucket | self.old_num_batch_per_bucket = self.num_batch_per_bucket | ||||
self.old_num_replicas = self.num_replicas | self.old_num_replicas = self.num_replicas | ||||
@@ -390,8 +389,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | ||||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | ||||
'num_replicas': self.num_replicas, | |||||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
'num_replicas': self.num_replicas | |||||
} | |||||
return states | return states | ||||
@@ -7,15 +7,11 @@ __all__ = [ | |||||
from typing import Dict, List, Union | from typing import Dict, List, Union | ||||
import math | import math | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
from .utils import NumConsumedSamplesArray | |||||
class ReproducibleSampler: | class ReproducibleSampler: | ||||
@@ -36,9 +32,6 @@ class ReproducibleSampler: | |||||
def state_dict(self): | def state_dict(self): | ||||
""" | """ | ||||
由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||||
了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -54,6 +47,14 @@ class ReproducibleSampler: | |||||
def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
pass | pass | ||||
@property | |||||
def num_repliacs(self): | |||||
return self._num_replicas | |||||
@num_repliacs.setter | |||||
def num_repliacs(self, value): | |||||
self._num_replicas = value | |||||
class RandomSampler(ReproducibleSampler): | class RandomSampler(ReproducibleSampler): | ||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
@@ -121,15 +122,11 @@ class RandomSampler(ReproducibleSampler): | |||||
indices = indices[self.num_consumed_samples:] | indices = indices[self.num_consumed_samples:] | ||||
indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx, index in enumerate(indices, start=1): | for idx, index in enumerate(indices, start=1): | ||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield index | yield index | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
""" | """ | ||||
@@ -150,8 +147,7 @@ class RandomSampler(ReproducibleSampler): | |||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
@@ -255,15 +251,11 @@ class SequentialSampler(RandomSampler): | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx, index in enumerate(indices, start=1): | for idx, index in enumerate(indices, start=1): | ||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield index | yield index | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
""" | """ | ||||
@@ -275,8 +267,8 @@ class SequentialSampler(RandomSampler): | |||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | ||||
'length': len(self.dataset), | |||||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
'length': len(self.dataset) | |||||
} | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
@@ -346,13 +338,9 @@ class SortedSampler(SequentialSampler): | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx, index in enumerate(indices, start=1): | for idx, index in enumerate(indices, start=1): | ||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield index | yield index | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
@@ -43,6 +43,8 @@ class NumConsumedSamplesArray: | |||||
array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | ||||
array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | ||||
暂时由于 sampler 的 batch 都是规整的,先保留 | |||||
:param buffer_size: 报错多少个历史。 | :param buffer_size: 报错多少个历史。 | ||||
:param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | :param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | ||||
""" | """ | ||||
@@ -45,9 +45,6 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||||
# todo 注释 | # todo 注释 | ||||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | ||||
# fastNLP 中初始化deque的默认大小 | |||||
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||||
# fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | ||||
# 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | # 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | ||||
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | ||||
@@ -445,8 +445,7 @@ class TestBucketedBatchSampler: | |||||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | ||||
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | ||||
""" | """ | ||||
测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 | |||||
偏多 | |||||
测试是否能够正确地恢复使用过的(forward)数据 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -454,6 +453,7 @@ class TestBucketedBatchSampler: | |||||
num_batch_per_bucket = 10 | num_batch_per_bucket = 10 | ||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | dataset = DatasetWithVaryLength(num_of_data=num_samples) | ||||
samplers = [] | samplers = [] | ||||
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas)) | |||||
for i in range(num_replicas): | for i in range(num_replicas): | ||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | ||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | ||||
@@ -472,8 +472,7 @@ class TestBucketedBatchSampler: | |||||
break | break | ||||
states = samplers[0].state_dict() | states = samplers[0].state_dict() | ||||
for i in range(len(already_seen_sets)): | for i in range(len(already_seen_sets)): | ||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||||
states['num_consumed_samples'] = num_consumed_samples_array[i] | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | ||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | ||||
drop_last=drop_last) | drop_last=drop_last) | ||||
@@ -489,8 +488,7 @@ class TestBucketedBatchSampler: | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | ||||
drop_last=drop_last) | drop_last=drop_last) | ||||
sampler.set_epoch(0) | sampler.set_epoch(0) | ||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||||
states['num_consumed_samples'] = num_consumed_samples_array[2] | |||||
if len(already_seen_sets)<3: | if len(already_seen_sets)<3: | ||||
return | return | ||||
already_seen_set = already_seen_sets[2] | already_seen_set = already_seen_sets[2] | ||||
@@ -502,8 +500,8 @@ class TestBucketedBatchSampler: | |||||
break | break | ||||
states = sampler.state_dict() | states = sampler.state_dict() | ||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||||
num_consumed_samples_array = list(range(len(dataset))) | |||||
states['num_consumed_samples'] = num_consumed_samples_array[count] | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | ||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | ||||
drop_last=drop_last) | drop_last=drop_last) | ||||
@@ -186,9 +186,11 @@ class TestRandomSamplerYh: | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | ||||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | ||||
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | ||||
# def test_num_consumed_samples_array(self, shuffle=True, pad=True, num_samples=100, num_replicas=2): | |||||
# 测试在 sampler 多生成的时候,可以仍然可以恢复 | # 测试在 sampler 多生成的时候,可以仍然可以恢复 | ||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | dataset = DatasetWithVaryLength(num_of_data=num_samples) | ||||
samplers = [] | samplers = [] | ||||
num_consumed_samples_array = list(range(0, len(dataset)+num_replicas, num_replicas)) | |||||
for i in range(num_replicas): | for i in range(num_replicas): | ||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
sampler.set_epoch(0) | sampler.set_epoch(0) | ||||
@@ -205,8 +207,7 @@ class TestRandomSamplerYh: | |||||
break | break | ||||
states = samplers[0].state_dict() | states = samplers[0].state_dict() | ||||
for i in range(len(already_seen_sets)): | for i in range(len(already_seen_sets)): | ||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||||
states['num_consumed_samples'] = num_consumed_samples_array[i] | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
already_seen_set = deepcopy(already_seen_sets[i]) | already_seen_set = deepcopy(already_seen_sets[i]) | ||||
for batch in sampler: | for batch in sampler: | ||||
@@ -215,12 +216,11 @@ class TestRandomSamplerYh: | |||||
# 测试保存之后再次保存 | # 测试保存之后再次保存 | ||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
sampler.set_epoch(0) | sampler.set_epoch(0) | ||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||||
if len(already_seen_sets)<3: | if len(already_seen_sets)<3: | ||||
return | return | ||||
already_seen_set = already_seen_sets[2] | already_seen_set = already_seen_sets[2] | ||||
count = 0 | count = 0 | ||||
num_consumed_samples_array = list(range(0, num_samples)) | |||||
for idx in sampler: | for idx in sampler: | ||||
already_seen_set.add(idx) | already_seen_set.add(idx) | ||||
count += 1 | count += 1 | ||||
@@ -228,8 +228,7 @@ class TestRandomSamplerYh: | |||||
break | break | ||||
states = sampler.state_dict() | states = sampler.state_dict() | ||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||||
states['num_consumed_samples'] = num_consumed_samples_array[count] | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
sampler.load_state_dict(states) | sampler.load_state_dict(states) | ||||
sampler.set_epoch(0) | sampler.set_epoch(0) | ||||
@@ -446,12 +445,12 @@ class TestSortedSampler: | |||||
@pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, pad, num_replica, num_of_data): | |||||
def test_multi(self, pad, num_replicas, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replica): | |||||
for i in range(num_replicas): | |||||
sampler = SortedSampler(dataset=data, length=data.data) | sampler = SortedSampler(dataset=data, length=data.data) | ||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||||
samplers.append(sampler) | samplers.append(sampler) | ||||
# 保证顺序是没乱的 | # 保证顺序是没乱的 | ||||
@@ -600,12 +599,12 @@ class TestSequentialSampler: | |||||
@pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, pad, num_replica, num_of_data): | |||||
def test_multi(self, pad, num_replicas, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replica): | |||||
for i in range(num_replicas): | |||||
sampler = SequentialSampler(dataset=data) | sampler = SequentialSampler(dataset=data) | ||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||||
samplers.append(sampler) | samplers.append(sampler) | ||||
# 保证顺序是没乱的 | # 保证顺序是没乱的 | ||||