
1. Remove num_consumed_samples_array from the samplers; the value is now derived from batch_size and related quantities. 2. Add an all_gather_object interface to Metric to make evaluation easier.

tags/v1.0.0alpha
yh_cc committed 2 years ago · parent commit d813f31f9a
10 changed files with 86 additions and 87 deletions

  1. fastNLP/core/drivers/torch_driver/torch_driver.py (+14 -4)
  2. fastNLP/core/metrics/classify_f1_pre_rec_metric.py (+6 -9)
  3. fastNLP/core/metrics/metric.py (+13 -1)
  4. fastNLP/core/metrics/span_f1_pre_rec_metric.py (+6 -9)
  5. fastNLP/core/samplers/reproducible_batch_sampler.py (+17 -18)
  6. fastNLP/core/samplers/reproducible_sampler.py (+11 -23)
  7. fastNLP/core/samplers/utils.py (+2 -0)
  8. fastNLP/envs/env.py (+0 -3)
  9. tests/core/samplers/test_reproducible_batch_sampler.py (+6 -8)
  10. tests/core/samplers/test_reproducible_sampler.py (+11 -12)

fastNLP/core/drivers/torch_driver/torch_driver.py (+14 -4)

@@ -188,17 +188,27 @@ class TorchDriver(Driver):
             num_consumed_batches = states.pop('num_consumed_batches')
             if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
                 sampler_states = sampler.state_dict()
-                # If present, num_consumed_samples needs special handling: the DataLoader prefetches, so using the
-                # sampler's own num_consumed_samples directly would overcount what was actually consumed.
+                # num_consumed_samples needs special handling: the DataLoader prefetches, so using the
+                # sampler's own num_consumed_samples directly would overcount what was actually consumed.
                 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
                 if num_consumed_samples_array is not None:
+                    if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size has to be considered
-                        try:
+                        if dataloader_args.batch_size is not None:
                             num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                        except:  # batch_size may be None, in which case some precision is lost
+                        else:  # batch_size may be None, in which case some precision is lost
+                            logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                           "it may cause missing some samples when reload.")
                             num_consumed_batches = sampler_states['num_consumed_samples']
                     sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
                     assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                else:
+                    if dataloader_args.batch_size is not None:
+                        sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                                 * num_consumed_batches
+                    else:
+                        logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                       "it may cause missing some samples when reload.")

                 states['sampler_states'] = sampler_states
             else:
                 raise RuntimeError(
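
In short: when a sampler no longer supplies num_consumed_samples_array, the driver derives the consumed-sample count arithmetically at save time. A minimal sketch of that conversion, assuming num_consumed_batches, batch_size and num_replicas are all known (the helper name is illustrative, not part of the commit):

    def derive_num_consumed_samples(num_consumed_batches, batch_size, num_replicas):
        # Every batch yielded by a batch sampler consumes num_replicas * batch_size
        # samples across all ranks, so the Trainer's batch counter converts directly.
        if batch_size is None:
            # batch_size unavailable (e.g. an opaque custom batch_sampler): the driver
            # falls back to the sampler's own counter, which may include prefetched
            # samples and therefore lose a few of them on reload.
            return None
        return num_replicas * batch_size * num_consumed_batches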


fastNLP/core/metrics/classify_f1_pre_rec_metric.py (+6 -9)

@@ -63,15 +63,12 @@ class ClassifyFPreRecMetric(Metric):
         evaluate_result = {}

         # Collect the results from every rank via all_gather_object, then sum them up.
-        if self.aggregate_when_get_metric:
-            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
-            tps, fps, fns = zip(*ls)
-            _tp, _fp, _fn = Counter(), Counter(), Counter()
-            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
-                for _c in cs:
-                    c.update(_c)
-        else:
-            _tp, _fp, _fn = self._tp, self._fp, self._tp
+        ls = self.all_gather_object([self._tp, self._fp, self._fn])
+        tps, fps, fns = zip(*ls)
+        _tp, _fp, _fn = Counter(), Counter(), Counter()
+        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+            for _c in cs:
+                c.update(_c)

         if not self.only_gross or self.f_type == 'macro':
             tags = set(_fn.keys())


fastNLP/core/metrics/metric.py (+13 -1)

@@ -4,7 +4,7 @@ __all__ = [

 from abc import abstractmethod

-from typing import Union
+from typing import Union, List
 import functools
 from contextlib import contextmanager
 import numpy as np
@@ -180,3 +180,15 @@ class Metric:
         """
         for element in self.elements.values():
             element.to(device)
+
+    def all_gather_object(self, obj, group=None) -> List:
+        """
+        Gather obj from every rank so that each rank ends up with all of them; returns a list whose entries are the obj of each rank, in rank order.
+
+        :param obj: The object to gather; it must be picklable.
+        :param group:
+        :return: -> List[obj0, obj1, ...] where obj0 is the obj on rank 0, obj1 the obj on rank 1, ...
+        """
+        if self.aggregate_when_get_metric:
+            return self.backend.all_gather_object(obj, group=group)
+        return [obj]
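
A usage sketch for the new interface, mirroring what the two F1 metrics in this commit do (here `metric` stands for any Metric instance and the Counter contents are made up):

    from collections import Counter

    tp = Counter({'PER': 3, 'LOC': 1})       # this rank's true-positive counts

    # One entry per rank when aggregate_when_get_metric is True, else just [tp]:
    gathered = metric.all_gather_object(tp)
    total_tp = Counter()
    for c in gathered:
        total_tp.update(c)                    # sum the per-rank counts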

fastNLP/core/metrics/span_f1_pre_rec_metric.py (+6 -9)

@@ -264,15 +264,12 @@ class SpanFPreRecMetric(Metric):
         evaluate_result = {}

         # Collect the results from every rank via all_gather_object, then sum them up.
-        if self.aggregate_when_get_metric:
-            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
-            tps, fps, fns = zip(*ls)
-            _tp, _fp, _fn = Counter(), Counter(), Counter()
-            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
-                for _c in cs:
-                    c.update(_c)
-        else:
-            _tp, _fp, _fn = self._tp, self._fp, self._tp
+        ls = self.all_gather_object([self._tp, self._fp, self._fn])
+        tps, fps, fns = zip(*ls)
+        _tp, _fp, _fn = Counter(), Counter(), Counter()
+        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+            for _c in cs:
+                c.update(_c)

         if not self.only_gross or self.f_type == 'macro':
             tags = set(_fn.keys())


fastNLP/core/samplers/reproducible_batch_sampler.py (+17 -18)

@@ -13,9 +13,8 @@ import numpy as np

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.log import logger
-from .utils import create_array, NumConsumedSamplesArray
+from .utils import create_array
 from abc import abstractmethod
-from fastNLP.envs.env import FASTNLP_DEQUE_SIZE


 class ReproducibleBatchSampler:
@@ -37,9 +36,6 @@ class ReproducibleBatchSampler:
     @abstractmethod
     def state_dict(self):
         """
-        Since today's DataLoaders all prefetch data, refer to the handling of num_consumed_samples_array in
-        RandomBatchSampler's states to set this value correctly. The idea is to record the num_consumed_samples of
-        every index; Trainer.save then picks the entry matching how many samples the Trainer actually forwarded.

         :return:
         """
@@ -57,6 +53,14 @@ class ReproducibleBatchSampler:
     def batch_idx_in_epoch(self):
         raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")

+    @property
+    def num_replicas(self):
+        return self._num_replicas
+
+    @num_replicas.setter
+    def num_replicas(self, value):
+        self._num_replicas = value
+

 class RandomBatchSampler(ReproducibleBatchSampler):
     # The values of these two arguments should be fetched via the driver's get_dataloader_args function;
@@ -105,24 +109,21 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         else:
             index_list = self.index_list

-        # Remember each batch's consumed_samples; this is needed because today's dataloaders prefetch data, so only
+        # Deprecated for now. Remember each batch's consumed_samples; because dataloaders prefetch data, only
         # together with batch_idx_in_epoch in the Trainer can the actual consumption be determined. This variable
         # records the true num_consumed_samples at each yield.
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
-                                                                  num_consumed_samples=self.num_consumed_samples)
+        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+        #                                                           num_consumed_samples=self.num_consumed_samples)
         for idx in index_list:
             batch.append(idx)
             if len(batch) == self.batch_size:
                 self.num_consumed_samples += self.batch_size  # [16, 32, 48, 64,..., ]
-                self.num_consumed_samples_array.push(self.num_consumed_samples)
                 yield batch
                 batch = []
         if len(batch) > 0 and not self.drop_last:
             self.num_consumed_samples += len(batch)
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield batch
         # Reset to avoid boundary-condition issues
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')

     def __len__(self) -> int:
         if self.drop_last:
@@ -136,7 +137,6 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             "num_consumed_samples": self.num_consumed_samples,
             'sampler_type': self.__class__.__name__
         }
-        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
         return states

     def load_state_dict(self, states: Dict):
@@ -327,15 +327,14 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
             batches = batches[:-1]

-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
-                                                                  num_consumed_samples=self.num_consumed_samples)
+        # Deprecated for now
+        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+        #                                                           num_consumed_samples=self.num_consumed_samples)
         for batch in batches:
             self.num_consumed_samples += self.num_replicas * len(batch)
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield list(map(int, batch))
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
         self.old_batch_size = self.batch_size
         self.old_num_batch_per_bucket = self.num_batch_per_bucket
         self.old_num_replicas = self.num_replicas
@@ -390,8 +389,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                   'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
                   'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
-                  'num_replicas': self.num_replicas,
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'num_replicas': self.num_replicas
+                  }

         return states
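
The restore path is unchanged: load_state_dict still only consumes num_consumed_samples. A hedged round-trip sketch under the new scheme, assuming make_sampler() rebuilds an identically configured BucketedBatchSampler (the helper is illustrative):

    sampler = make_sampler()
    for i, batch in enumerate(sampler, start=1):
        if i == 3:                            # suppose training stopped after 3 real batches
            break
    states = sampler.state_dict()
    # The driver overwrites the possibly prefetch-inflated counter using the
    # batch count it tracked itself:
    states['num_consumed_samples'] = sampler.num_replicas * sampler.batch_size * 3

    resumed = make_sampler()
    resumed.load_state_dict(states)           # continues from the 4th batch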



fastNLP/core/samplers/reproducible_sampler.py (+11 -23)

@@ -7,15 +7,11 @@ __all__ = [

 from typing import Dict, List, Union
 import math
-import os

 import numpy as np

 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
-from .utils import NumConsumedSamplesArray
-


 class ReproducibleSampler:
@@ -36,9 +32,6 @@ class ReproducibleSampler:

     def state_dict(self):
         """
-        Since today's DataLoaders all prefetch data, refer to the handling of num_consumed_samples_array in
-        RandomSampler's states to set this value correctly. The idea is to record the num_consumed_samples of
-        every index; Trainer.save then picks the entry matching how many samples the Trainer actually forwarded.

         :return:
         """
@@ -54,6 +47,14 @@ class ReproducibleSampler:
     def set_epoch(self, epoch):
         pass

+    @property
+    def num_replicas(self):
+        return self._num_replicas
+
+    @num_replicas.setter
+    def num_replicas(self, value):
+        self._num_replicas = value
+

 class RandomSampler(ReproducibleSampler):
     def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
@@ -121,15 +122,11 @@ class RandomSampler(ReproducibleSampler):
         indices = indices[self.num_consumed_samples:]
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')

     def generate_indices(self) -> List[int]:
         """
@@ -150,8 +147,7 @@ class RandomSampler(ReproducibleSampler):

     def state_dict(self) -> Dict:
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
-                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle}
         return states

     def load_state_dict(self, states: Dict):
@@ -255,15 +251,11 @@ class SequentialSampler(RandomSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')

     def generate_indices(self) -> List[int]:
         """
@@ -275,8 +267,8 @@ class SequentialSampler(RandomSampler):

     def state_dict(self) -> Dict:
         states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
-                  'length': len(self.dataset),
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'length': len(self.dataset)
+                  }
         return states

     def load_state_dict(self, states: Dict):
@@ -346,13 +338,9 @@ class SortedSampler(SequentialSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
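
Each yielded index advances num_consumed_samples by num_replicas on every rank, so without the buffer the valid resume points are simply multiples of num_replicas; the updated tests below rebuild exactly that progression:

    # e.g. with 2 replicas and 10 samples the counter passes through:
    num_replicas, num_samples = 2, 10
    checkpoints = list(range(0, num_samples + num_replicas, num_replicas))
    assert checkpoints == [0, 2, 4, 6, 8, 10]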


fastNLP/core/samplers/utils.py (+2 -0)

@@ -43,6 +43,8 @@ class NumConsumedSamplesArray:
         array[9]  # outputs 9, i.e. the true num_consumed_samples at this position.
         array[6]  # raises, because only the 3 most recent entries ([7, 8, 9]) are kept and 6 is outside the buffer.

+        Kept for now, since the sampler's batches are all of regular size.
+
         :param buffer_size: How many historical values to keep.
         :param num_consumed_samples: The initial num_consumed_samples value.
         """


fastNLP/envs/env.py (+0 -3)

@@ -45,9 +45,6 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
 # todo: document
 FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"

-# Default size fastNLP uses when initializing a deque
-FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'
-
 # Used to switch off fastNLP's 1. barrier and 2. gather/broadcast. The default '0' disables nothing;
 # '1' skips fastNLP's barrier; '2' disables both barrier and gather/broadcast.
 FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'
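
Since these are plain environment variables, disabling synchronization for a run is just, for example:

    import os
    os.environ['FASTNLP_NO_SYNC'] = '2'   # skip barrier as well as gather/broadcast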


tests/core/samplers/test_reproducible_batch_sampler.py (+6 -8)

@@ -445,8 +445,7 @@ class TestBucketedBatchSampler:
     @pytest.mark.parametrize('num_replicas', [1, 2, 3])
     def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
         """
-        Test that data already used (forwarded) can be restored correctly; because the DataLoader prefetches,
-        the sampler's own num_consumed_samples may run ahead.
+        Test that data already used (forwarded) can be restored correctly.

         :return:
         """
@@ -454,6 +453,7 @@
         num_batch_per_bucket = 10
         dataset = DatasetWithVaryLength(num_of_data=num_samples)
         samplers = []
+        num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
         for i in range(num_replicas):
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
@@ -472,8 +472,7 @@
                 break
         states = samplers[0].state_dict()
         for i in range(len(already_seen_sets)):
-            if states['num_consumed_samples_array'] is not None:
-                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                            drop_last=drop_last)
@@ -489,8 +488,7 @@
                                        num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                        drop_last=drop_last)
         sampler.set_epoch(0)
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
+        states['num_consumed_samples'] = num_consumed_samples_array[2]
         if len(already_seen_sets)<3:
             return
         already_seen_set = already_seen_sets[2]
@@ -502,8 +500,8 @@
                 break

         states = sampler.state_dict()
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        num_consumed_samples_array = list(range(len(dataset)))
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
         sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
                                        num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                        drop_last=drop_last)


tests/core/samplers/test_reproducible_sampler.py (+11 -12)

@@ -186,9 +186,11 @@ class TestRandomSamplerYh:
     @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
     @pytest.mark.parametrize('num_replicas', [1, 2, 3])
     def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
+        # def test_num_consumed_samples_array(self, shuffle=True, pad=True, num_samples=100, num_replicas=2):
        # Test that restoration still works when the sampler yields more than was actually consumed
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        samplers = []
+        num_consumed_samples_array = list(range(0, len(dataset)+num_replicas, num_replicas))
        for i in range(num_replicas):
            sampler = RandomSampler(dataset, shuffle=shuffle)
            sampler.set_epoch(0)
@@ -205,8 +207,7 @@
                 break
         states = samplers[0].state_dict()
         for i in range(len(already_seen_sets)):
-            if states['num_consumed_samples_array'] is not None:
-                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
             sampler = RandomSampler(dataset, shuffle=shuffle)
             already_seen_set = deepcopy(already_seen_sets[i])
             for batch in sampler:
@@ -215,12 +216,11 @@
         # Test saving again after restoring from a save
         sampler = RandomSampler(dataset, shuffle=shuffle)
         sampler.set_epoch(0)
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
         if len(already_seen_sets)<3:
             return
         already_seen_set = already_seen_sets[2]
         count = 0
+        num_consumed_samples_array = list(range(0, num_samples))
         for idx in sampler:
             already_seen_set.add(idx)
             count += 1
@@ -228,8 +228,7 @@
                 break

         states = sampler.state_dict()
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
         sampler = RandomSampler(dataset, shuffle=shuffle)
         sampler.load_state_dict(states)
         sampler.set_epoch(0)
@@ -446,12 +445,12 @@ class TestSortedSampler:
     @pytest.mark.parametrize('pad', [True, False])
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, pad, num_replica, num_of_data):
+    def test_multi(self, pad, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = SortedSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             samplers.append(sampler)

         # make sure the ordering is intact
@@ -600,12 +599,12 @@ class TestSequentialSampler:
     @pytest.mark.parametrize('pad', [True, False])
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, pad, num_replica, num_of_data):
+    def test_multi(self, pad, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = SequentialSampler(dataset=data)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             samplers.append(sampler)

         # make sure the ordering is intact
