
1. Remove num_consumed_samples_array from the samplers and derive the value from batch_size instead. 2. Add an all_gather_object interface to Metric to make evaluation easier.

tags/v1.0.0alpha
yh_cc 2 years ago
commit d813f31f9a
10 changed files with 86 additions and 87 deletions
  1. +14 -4   fastNLP/core/drivers/torch_driver/torch_driver.py
  2. +6  -9   fastNLP/core/metrics/classify_f1_pre_rec_metric.py
  3. +13 -1   fastNLP/core/metrics/metric.py
  4. +6  -9   fastNLP/core/metrics/span_f1_pre_rec_metric.py
  5. +17 -18  fastNLP/core/samplers/reproducible_batch_sampler.py
  6. +11 -23  fastNLP/core/samplers/reproducible_sampler.py
  7. +2  -0   fastNLP/core/samplers/utils.py
  8. +0  -3   fastNLP/envs/env.py
  9. +6  -8   tests/core/samplers/test_reproducible_batch_sampler.py
  10. +11 -12 tests/core/samplers/test_reproducible_sampler.py

+14 -4   fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -188,17 +188,27 @@ class TorchDriver(Driver):
         num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
-            # If present, num_consumed_samples needs special handling: because the DataLoader prefetches data,
-            # using the sampler's num_consumed_samples directly would over-count what was actually consumed.
+            # num_consumed_samples needs special handling: because the DataLoader prefetches data,
+            # using the sampler's num_consumed_samples directly would over-count what was actually consumed.
             num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
             if num_consumed_samples_array is not None:
                 if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size must be taken into account
-                    try:
+                    if dataloader_args.batch_size is not None:
                         num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    except:  # batch_size may be None; then some precision is lost
+                    else:  # batch_size may be None; then some precision is lost
+                        logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                       "it may cause missing some samples when reload.")
                         num_consumed_batches = sampler_states['num_consumed_samples']
                 sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
                 assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            else:
+                if dataloader_args.batch_size is not None:
+                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                             * num_consumed_batches
+                else:
+                    logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                   "it may cause missing some samples when reload.")

             states['sampler_states'] = sampler_states
         else:
             raise RuntimeError(
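
The new fallback derives the consumed-sample count from what the Trainer really forwarded instead of the sampler's prefetch-inflated counter. A minimal sketch of that arithmetic, with hypothetical names (`consumed_samples`, `fallback`); the per-replica factor mirrors the `sampler.num_replicas * dataloader_args.batch_size * num_consumed_batches` expression above:

    # Sketch only: `num_consumed_batches` counts batches the Trainer actually forwarded,
    # so prefetched-but-unused batches are excluded; `batch_size` may be None when it
    # cannot be read back from the DataLoader.
    def consumed_samples(num_consumed_batches, num_replicas, batch_size, fallback):
        if batch_size is not None:
            # every forwarded batch consumed num_replicas * batch_size samples globally
            return num_replicas * batch_size * num_consumed_batches
        # without batch_size, fall back to the sampler's own (possibly inflated) counter
        return fallback

    # e.g. 2 replicas, batch_size 16, 5 batches forwarded -> 160 samples consumed
    assert consumed_samples(5, 2, 16, fallback=999) == 160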


+6 -9   fastNLP/core/metrics/classify_f1_pre_rec_metric.py

@@ -63,15 +63,12 @@ class ClassifyFPreRecMetric(Metric):
         evaluate_result = {}

         # Collect the results from every device via all_gather_object and sum them up.
-        if self.aggregate_when_get_metric:
-            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
-            tps, fps, fns = zip(*ls)
-            _tp, _fp, _fn = Counter(), Counter(), Counter()
-            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
-                for _c in cs:
-                    c.update(_c)
-        else:
-            _tp, _fp, _fn = self._tp, self._fp, self._tp
+        ls = self.all_gather_object([self._tp, self._fp, self._fn])
+        tps, fps, fns = zip(*ls)
+        _tp, _fp, _fn = Counter(), Counter(), Counter()
+        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+            for _c in cs:
+                c.update(_c)

         if not self.only_gross or self.f_type == 'macro':
             tags = set(_fn.keys())
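
The gather-and-merge above boils down to a key-wise sum of Counters across ranks; a self-contained sketch with made-up per-rank counts shows what the zip/update loop computes:

    from collections import Counter

    # pretend two ranks each returned their local [tp, fp, fn] via all_gather_object
    ls = [
        [Counter({'POS': 3}), Counter({'POS': 1}), Counter({'POS': 0})],  # rank 0
        [Counter({'POS': 2}), Counter({'POS': 0}), Counter({'POS': 1})],  # rank 1
    ]
    tps, fps, fns = zip(*ls)  # regroup per statistic rather than per rank
    _tp, _fp, _fn = Counter(), Counter(), Counter()
    for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
        for _c in cs:
            c.update(_c)  # Counter.update adds counts key-wise
    assert _tp['POS'] == 5 and _fp['POS'] == 1 and _fn['POS'] == 1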


+13 -1   fastNLP/core/metrics/metric.py

@@ -4,7 +4,7 @@ __all__ = [

 from abc import abstractmethod

-from typing import Union
+from typing import Union, List
 import functools
 from contextlib import contextmanager
 import numpy as np
@@ -180,3 +180,15 @@ class Metric:
         """
         for element in self.elements.values():
             element.to(device)
+
+    def all_gather_object(self, obj, group=None) -> List:
+        """
+        Gather obj from every rank onto each rank. Returns a list whose entries are, in order, the obj from each rank.
+
+        :param obj: the object to gather; it must be picklable.
+        :param group:
+        :return: -> List[obj0, obj1, ...] where obj0 is the obj on rank 0, obj1 the obj on rank 1, ...
+        """
+        if self.aggregate_when_get_metric:
+            return self.backend.all_gather_object(obj, group=group)
+        return [obj]
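
A hedged usage sketch: a custom Metric can now gather per-rank state with one call, and because the method returns `[obj]` when `aggregate_when_get_metric` is False, the same merging code works on one process or many. `MyAccuracy` and its fields are hypothetical, and the import path is assumed for this release:

    from fastNLP.core.metrics import Metric  # assumed import path

    class MyAccuracy(Metric):  # hypothetical subclass for illustration
        def __init__(self):
            super().__init__()
            self._correct, self._total = 0, 0  # plain counters for the sketch

        def update(self, pred, target):
            self._correct += int(pred == target)
            self._total += 1

        def get_metric(self) -> dict:
            # one list entry per rank; on a single process this is just [[correct, total]]
            gathered = self.all_gather_object([self._correct, self._total])
            correct = sum(c for c, _ in gathered)
            total = sum(t for _, t in gathered)
            return {'acc': correct / max(total, 1)}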

+6 -9   fastNLP/core/metrics/span_f1_pre_rec_metric.py

@@ -264,15 +264,12 @@ class SpanFPreRecMetric(Metric):
         evaluate_result = {}

         # Collect the results from every device via all_gather_object and sum them up.
-        if self.aggregate_when_get_metric:
-            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
-            tps, fps, fns = zip(*ls)
-            _tp, _fp, _fn = Counter(), Counter(), Counter()
-            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
-                for _c in cs:
-                    c.update(_c)
-        else:
-            _tp, _fp, _fn = self._tp, self._fp, self._tp
+        ls = self.all_gather_object([self._tp, self._fp, self._fn])
+        tps, fps, fns = zip(*ls)
+        _tp, _fp, _fn = Counter(), Counter(), Counter()
+        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+            for _c in cs:
+                c.update(_c)

         if not self.only_gross or self.f_type == 'macro':
             tags = set(_fn.keys())


+17 -18  fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -13,9 +13,8 @@ import numpy as np

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.log import logger
-from .utils import create_array, NumConsumedSamplesArray
+from .utils import create_array
 from abc import abstractmethod
-from fastNLP.envs.env import FASTNLP_DEQUE_SIZE


 class ReproducibleBatchSampler:
@@ -37,9 +36,6 @@ class ReproducibleBatchSampler:
     @abstractmethod
     def state_dict(self):
         """
-        Since today's DataLoaders all prefetch data, refer to how num_consumed_samples_array is handled in the states of
-        RandomBatchSampler to set this value correctly. The idea is to record the num_consumed_samples for each index;
-        at Trainer.save time, the entry matching the number of samples the Trainer really forwarded is read and stored.

         :return:
         """
@@ -57,6 +53,14 @@ class ReproducibleBatchSampler:
     def batch_idx_in_epoch(self):
         raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")

+    @property
+    def num_replicas(self):
+        return self._num_replicas
+
+    @num_replicas.setter
+    def num_replicas(self, value):
+        self._num_replicas = value


 class RandomBatchSampler(ReproducibleBatchSampler):
     # The values of these two parameters should be obtained from the driver's get_dataloader_args function;
@@ -105,24 +109,21 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         else:
             index_list = self.index_list

-        # Remember the consumed_samples of every batch. This is needed because today's dataloaders all prefetch data,
+        # Temporarily unused. Remember the consumed_samples of every batch. This is needed because today's dataloaders all prefetch data,
         # so the data actually consumed can only be determined together with batch_idx_in_epoch from the Trainer. This
         # variable must record the real num_consumed_samples at every yield.
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
-                                                                  num_consumed_samples=self.num_consumed_samples)
+        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+        #                                                           num_consumed_samples=self.num_consumed_samples)
         for idx in index_list:
             batch.append(idx)
             if len(batch) == self.batch_size:
                 self.num_consumed_samples += self.batch_size  # [16, 32, 48, 64, ...]
-                self.num_consumed_samples_array.push(self.num_consumed_samples)
                 yield batch
                 batch = []
         if len(batch) > 0 and not self.drop_last:
             self.num_consumed_samples += len(batch)
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield batch
         # reset to guard against boundary-condition problems
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')

     def __len__(self) -> int:
         if self.drop_last:
@@ -136,7 +137,6 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             "num_consumed_samples": self.num_consumed_samples,
             'sampler_type': self.__class__.__name__
         }
-        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
         return states

     def load_state_dict(self, states: Dict):
@@ -327,15 +327,14 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
             batches = batches[:-1]

-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
-                                                                  num_consumed_samples=self.num_consumed_samples)
+        # Temporarily unused
+        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+        #                                                           num_consumed_samples=self.num_consumed_samples)
         for batch in batches:
             self.num_consumed_samples += self.num_replicas * len(batch)
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield list(map(int, batch))
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
         self.old_batch_size = self.batch_size
         self.old_num_batch_per_bucket = self.num_batch_per_bucket
         self.old_num_replicas = self.num_replicas
@@ -390,8 +389,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                   'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
                   'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
-                  'num_replicas': self.num_replicas,
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'num_replicas': self.num_replicas
+                  }

         return states
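
A sketch of the save/restore round trip under the new scheme, mirroring the tests later in this commit. `dataset` comes from the test helpers, `num_forwarded_batches` is a hypothetical stand-in for the Trainer's real batch count, and a single replica is assumed:

    sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=16,
                                   num_batch_per_bucket=10, shuffle=True, drop_last=False)
    states = sampler.state_dict()  # no 'num_consumed_samples_array' entry any more
    # the driver, not the sampler, now rewrites the counter from what was really forwarded
    states['num_consumed_samples'] = sampler.num_replicas * 16 * num_forwarded_batches
    restored = BucketedBatchSampler(dataset, length=dataset.data, batch_size=16,
                                    num_batch_per_bucket=10, shuffle=True, drop_last=False)
    restored.load_state_dict(states)  # resumes after the samples actually consumed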




+11 -23  fastNLP/core/samplers/reproducible_sampler.py

@@ -7,15 +7,11 @@ __all__ = [

 from typing import Dict, List, Union
 import math
-import os

 import numpy as np

 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
-from .utils import NumConsumedSamplesArray


 class ReproducibleSampler:
@@ -36,9 +32,6 @@ class ReproducibleSampler:

     def state_dict(self):
         """
-        Since today's DataLoaders all prefetch data, refer to how num_consumed_samples_array is handled in the states of
-        RandomSampler to set this value correctly. The idea is to record the num_consumed_samples for each index;
-        at Trainer.save time, the entry matching the number of samples the Trainer really forwarded is read and stored.

         :return:
         """
@@ -54,6 +47,14 @@ class ReproducibleSampler:
     def set_epoch(self, epoch):
         pass

+    @property
+    def num_replicas(self):
+        return self._num_replicas
+
+    @num_replicas.setter
+    def num_replicas(self, value):
+        self._num_replicas = value


 class RandomSampler(ReproducibleSampler):
     def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
@@ -121,15 +122,11 @@ class RandomSampler(ReproducibleSampler):
         indices = indices[self.num_consumed_samples:]
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')

     def generate_indices(self) -> List[int]:
         """
@@ -150,8 +147,7 @@ class RandomSampler(ReproducibleSampler):

     def state_dict(self) -> Dict:
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
-                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle}
         return states

     def load_state_dict(self, states: Dict):
@@ -255,15 +251,11 @@ class SequentialSampler(RandomSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')

     def generate_indices(self) -> List[int]:
         """
@@ -275,8 +267,8 @@ class SequentialSampler(RandomSampler):

     def state_dict(self) -> Dict:
         states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
-                  'length': len(self.dataset),
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'length': len(self.dataset)
+                  }
         return states

     def load_state_dict(self, states: Dict):
@@ -346,13 +338,9 @@ class SortedSampler(SequentialSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
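
Because every yield in these samplers advances num_consumed_samples by num_replicas, the global count after i yields is exactly i * num_replicas. That is what lets the tests below replace the old array with a plain range; a quick check of this arithmetic:

    num_samples, num_replicas = 13, 3
    # position i holds the global consumption after i yields on each rank
    num_consumed_samples_array = list(range(0, num_samples + num_replicas, num_replicas))
    assert num_consumed_samples_array[0] == 0                 # nothing consumed yet
    assert num_consumed_samples_array[4] == 4 * num_replicas  # 4 yields per rank -> 12 samples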



+2 -0   fastNLP/core/samplers/utils.py

@@ -43,6 +43,8 @@ class NumConsumedSamplesArray:
     array[9]  # returns 9, i.e. the real num_consumed_samples at this position.
     array[6]  # raises, because only the 3 most recent entries are kept and 6 falls outside of them, i.e. [7, 8, 9]

+    Temporarily unused; kept for now because the samplers' batches are all regular.
+
     :param buffer_size: how many history entries to keep.
     :param num_consumed_samples: the value of the first num_consumed_samples.
     """


+0 -3   fastNLP/envs/env.py

@@ -45,9 +45,6 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
 # todo: document
 FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"

-# default size of the deques initialized inside fastNLP
-FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'
-
 # Used inside fastNLP to disable 1. barrier and 2. gather/broadcast. Default '0' means nothing is disabled;
 # '1' means fastNLP's barrier is skipped; '2' means both barrier and gather/broadcast are disabled.
 FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'
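
As a sketch of how a flag with these semantics is typically consumed (not fastNLP's actual implementation):

    import os

    def sync_level() -> int:
        # '0': nothing disabled; '1': barrier skipped; '2': barrier and gather/broadcast skipped
        return int(os.environ.get('FASTNLP_NO_SYNC', '0'))

    barrier_enabled = sync_level() < 1
    gather_enabled = sync_level() < 2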


+6 -8   tests/core/samplers/test_reproducible_batch_sampler.py

@@ -445,8 +445,7 @@ class TestBucketedBatchSampler:
     @pytest.mark.parametrize('num_replicas', [1, 2, 3])
     def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
         """
-        Test that data already used (forwarded) is restored correctly; because the DataLoader prefetches, the
-        sampler's own num_consumed_samples may run ahead
+        Test that data already used (forwarded) is restored correctly

         :return:
         """
@@ -454,6 +453,7 @@ class TestBucketedBatchSampler:
         num_batch_per_bucket = 10
         dataset = DatasetWithVaryLength(num_of_data=num_samples)
         samplers = []
+        num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
         for i in range(num_replicas):
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
@@ -472,8 +472,7 @@ class TestBucketedBatchSampler:
                 break
         states = samplers[0].state_dict()
         for i in range(len(already_seen_sets)):
-            if states['num_consumed_samples_array'] is not None:
-                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                            drop_last=drop_last)
@@ -489,8 +488,7 @@ class TestBucketedBatchSampler:
                                        num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                        drop_last=drop_last)
         sampler.set_epoch(0)
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
+        states['num_consumed_samples'] = num_consumed_samples_array[2]
         if len(already_seen_sets)<3:
             return
         already_seen_set = already_seen_sets[2]
@@ -502,8 +500,8 @@ class TestBucketedBatchSampler:
             break

         states = sampler.state_dict()
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        num_consumed_samples_array = list(range(len(dataset)))
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
         sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
                                        num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                        drop_last=drop_last)


+11 -12  tests/core/samplers/test_reproducible_sampler.py

@@ -186,9 +186,11 @@ class TestRandomSamplerYh:
     @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
     @pytest.mark.parametrize('num_replicas', [1, 2, 3])
     def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
+        # def test_num_consumed_samples_array(self, shuffle=True, pad=True, num_samples=100, num_replicas=2):
         # Test that state can still be restored when the sampler produced more than was consumed
         dataset = DatasetWithVaryLength(num_of_data=num_samples)
         samplers = []
+        num_consumed_samples_array = list(range(0, len(dataset)+num_replicas, num_replicas))
         for i in range(num_replicas):
             sampler = RandomSampler(dataset, shuffle=shuffle)
             sampler.set_epoch(0)
@@ -205,8 +207,7 @@ class TestRandomSamplerYh:
                 break
         states = samplers[0].state_dict()
         for i in range(len(already_seen_sets)):
-            if states['num_consumed_samples_array'] is not None:
-                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
             sampler = RandomSampler(dataset, shuffle=shuffle)
             already_seen_set = deepcopy(already_seen_sets[i])
             for batch in sampler:
@@ -215,12 +216,11 @@ class TestRandomSamplerYh:
         # Test saving again after a save
         sampler = RandomSampler(dataset, shuffle=shuffle)
         sampler.set_epoch(0)
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
         if len(already_seen_sets)<3:
             return
         already_seen_set = already_seen_sets[2]
         count = 0
+        num_consumed_samples_array = list(range(0, num_samples))
         for idx in sampler:
             already_seen_set.add(idx)
             count += 1
@@ -228,8 +228,7 @@ class TestRandomSamplerYh:
             break

         states = sampler.state_dict()
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
         sampler = RandomSampler(dataset, shuffle=shuffle)
         sampler.load_state_dict(states)
         sampler.set_epoch(0)
@@ -446,12 +445,12 @@ class TestSortedSampler:
     @pytest.mark.parametrize('pad', [True, False])
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, pad, num_replica, num_of_data):
+    def test_multi(self, pad, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = SortedSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             samplers.append(sampler)

         # make sure the order is not scrambled
@@ -600,12 +599,12 @@ class TestSequentialSampler:
     @pytest.mark.parametrize('pad', [True, False])
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, pad, num_replica, num_of_data):
+    def test_multi(self, pad, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = SequentialSampler(dataset=data)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             samplers.append(sampler)

         # make sure the order is not scrambled

