
Improve SpanFPreRecMetric and its tests

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
e8b6298fb1
8 changed files with 84 additions and 119 deletions
  1. +1  -1   fastNLP/core/drivers/torch_driver/dist_utils.py
  2. +1  -55  fastNLP/core/metrics/backend/torch_backend/backend.py
  3. +7  -10  fastNLP/core/metrics/element.py
  4. +3  -5   fastNLP/core/metrics/metric.py
  5. +23 -3   fastNLP/core/metrics/span_f1_pre_rec_metric.py
  6. +10 -3   fastNLP/envs/set_backend.py
  7. +2  -1   fastNLP/envs/set_env_on_import.py
  8. +37 -41  tests/core/metrics/test_span_f1_rec_acc_torch.py

+ 1
- 1
fastNLP/core/drivers/torch_driver/dist_utils.py View File

@@ -402,7 +402,7 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
if _TORCH_GREATER_EQUAL_1_8:
objs = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(objs, obj)
apply_to_collection(obj, torch.Tensor, _to_device, device=device) # make sure that any tensors end up on the current device
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # make sure that any tensors end up on the current device
return objs
group = group if group is not None else torch.distributed.group.WORLD
data = convert_to_tensors(obj, device=device)
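
The one-line fix above matters because `apply_to_collection` builds and returns a new collection instead of mutating `objs` in place; without reassigning the result, the gathered tensors would keep their original devices. A minimal sketch of the pattern, with a simplified stand-in for the real helper (the actual `apply_to_collection` used by fastNLP has a richer signature):

    import torch

    def _to_device(tensor, device):
        # returns a *new* tensor on the target device; the input is left untouched
        return tensor.to(device)

    def apply_to_collection(data, dtype, fn, **kwargs):
        # simplified stand-in: recurse into lists/dicts and apply fn to matching leaves
        if isinstance(data, dtype):
            return fn(data, **kwargs)
        if isinstance(data, list):
            return [apply_to_collection(d, dtype, fn, **kwargs) for d in data]
        if isinstance(data, dict):
            return {k: apply_to_collection(v, dtype, fn, **kwargs) for k, v in data.items()}
        return data

    objs = [{'loss': torch.tensor(1.0)}, {'loss': torch.tensor(2.0)}]
    apply_to_collection(objs, torch.Tensor, _to_device, device='cpu')         # result is discarded
    objs = apply_to_collection(objs, torch.Tensor, _to_device, device='cpu')  # result is kept, as in the patched line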


+ 1
- 55
fastNLP/core/metrics/backend/torch_backend/backend.py View File

@@ -33,8 +33,7 @@ class TorchBackend(Backend):
if dist.is_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
# tensor = self.all_gather_object(tensor)
tensor = fastnlp_torch_all_gather(tensor)
if isinstance(tensor[0], torch.Tensor):
tensor = torch.stack(tensor)
# step 1: aggregate the result
@@ -69,59 +68,6 @@ class TorchBackend(Backend):
def get_scalar(self, tensor) -> float:
return tensor.item()

@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for all processes.

Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)

Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
"""

if group is None:
group = dist.group.WORLD

# convert tensors to contiguous format
result = result.contiguous()

world_size = dist.get_world_size(group)
dist.barrier(group=group)

# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)

# 1. Gather sizes of all tensors
local_size = torch.tensor(result.shape, device=result.device)
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
dist.all_gather(local_sizes, local_size, group=group)
max_size = torch.stack(local_sizes).max(dim=0).values
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

# 2. If shapes are all the same, then do a simple gather:
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)

# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = torch.nn.functional.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
dist.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result

def tensor2numpy(self, tensor) -> np.array:
"""
Convert the given tensor to a numpy object.
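
For context on the deletion above: `_gather_all` existed because `dist.all_gather` requires tensors of identical shape on every rank, hence the size exchange, padding and trimming. The replacement, `fastnlp_torch_all_gather` (see the dist_utils.py hunk), relies on `dist.all_gather_object` on torch >= 1.8, which pickles arbitrary objects and therefore has no shape constraint. A minimal sketch of that approach, assuming an already-initialized default process group:

    import torch
    import torch.distributed as dist

    def gather_unequal_tensors(local_tensor):
        """Gather tensors whose shapes may differ across ranks."""
        objs = [None for _ in range(dist.get_world_size())]
        # all_gather_object pickles each rank's object, so no padding/trimming is needed
        dist.all_gather_object(objs, local_tensor)
        # objs[i] is the tensor contributed by rank i; as in the dist_utils.py fix,
        # the caller may still want to move the gathered tensors to the current device
        return objs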


+ 7
- 10
fastNLP/core/metrics/element.py View File

@@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK


class Element:
def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
def __init__(self, name, value: float, aggregate_method, backend: Backend):
self.name = name
self.init_value = value
self.aggregate_method = aggregate_method
self.name = name
if backend == 'auto':
raise RuntimeError("You have to specify the backend.")
raise RuntimeError(f"You have to specify the backend for Element:{self.name}.")
elif isinstance(backend, AutoBackend):
self.backend = backend
else:
@@ -41,14 +41,9 @@ class Element:
msg = 'If you see this message, please report a bug.'
if self.name and e.should_have_aggregate_method:
msg = f"Element:{self.name} has no specified `aggregate_method`."
elif e.should_have_aggregate_method:
msg = "Element has no specified `aggregate_method`."
elif self.name and not e.should_have_aggregate_method:
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
elif not e.should_have_aggregate_method:
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
if e.only_warn:
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
logger.warning(msg)
@@ -75,6 +70,7 @@ class Element:
return self._value

def get_scalar(self) -> float:
self._check_value_initialized()
return self.backend.get_scalar(self._value)

def fill_value(self, value):
@@ -96,7 +92,7 @@ class Element:

def _check_value_when_call(self):
if self.value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")

@@ -274,9 +270,10 @@ class Element:
"""
try:
if self._value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
return getattr(self._value, item)
except AttributeError as e:
logger.error(f"Element:{self.name} has no `{item}` attribute.")
raise e

+ 3
- 5
fastNLP/core/metrics/metric.py View File

@@ -35,7 +35,7 @@ class Metric:
def elements(self) -> dict:
return self._elements

def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element:
"""
Register an element object. Once registered, it can be accessed inside the Metric directly as self.{name}; the object can be treated as a
tensor of the corresponding backend and used directly for addition, subtraction, multiplication and division.
@@ -57,11 +57,9 @@ class Metric:
else:
backend = AutoBackend(backend)

# when name is None, generate a default name for the variable
if name is None:
name = f'ele_var_{len(self._elements)}'
assert name is not None and name not in self.elements

element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend)
self.elements[name] = element
setattr(self, name, element)
return element
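
With `name` now a required argument, a custom metric has to name every element explicitly (the name is also what shows up in error messages, per the element.py hunks above). A hedged sketch of the intended usage; the import path follows this diff, and the subclass is a made-up example rather than one from the library:

    from fastNLP.core.metrics.metric import Metric  # path as in this diff

    class CountMetric(Metric):
        def __init__(self, backend='auto'):
            super().__init__(backend=backend)
            # name is mandatory and must be unique within the metric
            self.total = self.register_element(name='total', value=0, aggregate_method='sum')

        def update(self, preds):
            # per the docstring, elements behave like backend tensors, so plain arithmetic works
            self.total += len(preds)

        def get_metric(self) -> dict:
            return {'total': self.total.get_scalar()}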


+ 23
- 3
fastNLP/core/metrics/span_f1_pre_rec_metric.py View File

@@ -219,6 +219,23 @@ class SpanFPreRecMetric(Metric):
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro',
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None:
r"""

:param tag_vocab: the tag :class:`~fastNLP.Vocabulary`. Supported tags are "B" (without a label) or "B-xxx" (where xxx is some label, e.g. NN in POS tagging);
    during decoding, tags with the same xxx are treated as the same label, e.g. ['B-NN', 'E-NN'] is merged into a single 'NN'.
:param str pred: the key used to fetch the prediction data from the dict passed to evaluate(). If None, the key `pred` is used.
:param str target: the key used to fetch the target data from the dict passed to evaluate(). If None, the key `target` is used.
:param str seq_len: the key used to fetch the sequence length data from the dict passed to evaluate(). If None, the key `seq_len` is used.
:param str encoding_type: currently supports bio, bmes, bmeso and bioes. Defaults to None, in which case the type is inferred automatically from tag_vocab.
:param list ignore_labels: a list of str. Labels in this list are excluded from the computation, e.g. passing ['NN'] in POS tagging means the 'NN' label is not counted.
:param bool only_gross: whether to compute only the overall f1, precision and recall; if False, the per-label f1, pre and rec are returned in addition to the overall values.
:param str f_type: `micro` or `macro`. `micro`: first sum the TP, FN and FP counts over all labels, then compute f, precision and recall; `macro`: compute f, precision and recall for each label separately, then average them (each label's f has the same weight).
:param float beta: the f_beta score, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. Common values are `beta=0.5, 1, 2`; with 0.5 precision is weighted more than recall, with 1 they are weighted equally, and with 2 recall is weighted more than precision.
:param str backend: currently four kinds of backends are supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided by the arguments actually passed to Metric.update();
    in most cases simply using 'auto' is fine.
:param bool aggregate_when_get_metric: whether, when computing the metric, to aggregate the values of the same element across processes before producing the metric.
    Meaningless when the backend does not support distributed execution.
"""
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
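
Since the docstring distinguishes micro and macro averaging, a tiny worked example of the difference (numbers invented purely for illustration):

    # label A: tp=8, fp=2, fn=2      label B: tp=1, fp=3, fn=3
    pre_micro = (8 + 1) / (8 + 1 + 2 + 3)   # pool the counts first: 9/14 ≈ 0.643
    rec_micro = (8 + 1) / (8 + 1 + 2 + 3)   # 9/14 ≈ 0.643
    pre_macro = (8 / 10 + 1 / 4) / 2        # average the per-label scores: 0.525
    rec_macro = (8 / 10 + 1 / 4) / 2        # 0.525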
@@ -255,7 +272,7 @@ class SpanFPreRecMetric(Metric):
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word.split('-')[1]
word = word[2:]
if word in self._true_positives:
continue
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
@@ -266,8 +283,8 @@ class SpanFPreRecMetric(Metric):
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
tags.update(self._false_positives.keys())
tags.update(self._true_positives.keys())
f_sum = 0
pre_sum = 0
rec_sum = 0
@@ -275,6 +292,9 @@ class SpanFPreRecMetric(Metric):
tp = self._true_positives[tag].get_scalar()
fn = self._false_negatives[tag].get_scalar()
fp = self._false_positives[tag].get_scalar()
if tp == fn == fp == 0:
continue

f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
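
The switch from `word.split('-')[1]` to `word[2:]` only matters for labels that themselves contain a hyphen: both give the same result for a plain 'b-loc', but split drops everything after the second hyphen. A small worked example:

    word = 'b-ner-loc'          # lowercased tag whose label itself contains a hyphen
    print(word.split('-')[1])   # 'ner'      -> label silently truncated
    print(word[2:])             # 'ner-loc'  -> full label preserved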


+ 10
- 3
fastNLP/envs/set_backend.py View File

@@ -150,7 +150,7 @@ def seed_jittor_global_seed(global_seed):
pass


def dump_fastnlp_backend(default:bool = False):
def dump_fastnlp_backend(default:bool = False, backend=None):
"""
Write the fastNLP settings to the ~/.fastNLP/envs/ folder.
If default is True, the file is saved as ~/.fastNLP/envs/default.json .
@@ -162,6 +162,7 @@ def dump_fastnlp_backend(default:bool = False):
The environment variable saved is FASTNLP_BACKEND .

:param default:
:param backend: which backend value to save; allowed values are ['torch', 'paddle', 'jittor']. If None, the value from the environment variable is used.
:return:
"""
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
@@ -176,10 +177,16 @@ def dump_fastnlp_backend(default:bool = False):
os.makedirs(os.path.dirname(env_path), exist_ok=True)

envs = {}
if FASTNLP_BACKEND in os.environ:
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND]
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now."
if backend is None:
if FASTNLP_BACKEND in os.environ:
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND]
else:
envs[FASTNLP_BACKEND] = backend
if len(envs):
with open(env_path, 'w', encoding='utf8') as f:
json.dump(fp=f, obj=envs)

print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.")
else:
raise RuntimeError("No backend specified.")
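
With the new parameter, the backend to persist can be passed explicitly instead of being read from the FASTNLP_BACKEND environment variable. A hedged usage sketch (function name and allowed values as in this hunk):

    from fastNLP.envs.set_backend import dump_fastnlp_backend  # path as in this diff

    # persist 'torch' as the default backend (~/.fastNLP/envs/default.json)
    # without requiring FASTNLP_BACKEND to be set beforehand
    dump_fastnlp_backend(default=True, backend='torch')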

+ 2
- 1
fastNLP/envs/set_env_on_import.py View File

@@ -48,7 +48,8 @@ def set_env_on_import_paddle():
# TODO jittor may need set this
def set_env_on_import_jittor():
# todo: FASTNLP_GLOBAL_RANK and FASTNLP_BACKEND_LAUNCH still need to be set here
pass
if 'log_silent' not in os.environ:
os.environ['log_silent'] = '1'


def set_env_on_import():
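
The jittor hook now silences jittor's logging by default, but only when the user has not already chosen a value. A minimal illustration of opting back in (environment-variable name as in this hunk):

    import os
    os.environ.setdefault('log_silent', '0')  # keep jittor logging enabled
    import fastNLP  # set_env_on_import_jittor() now leaves log_silent untouched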


tests/core/metrics/test_f1_rec_acc_torch.py → tests/core/metrics/test_span_f1_rec_acc_torch.py View File

@@ -99,7 +99,7 @@ def _test(local_rank: int,
assert my_result == sklearn_metric


class SpanFPreRecMetricTest(unittest.TestCase):
class TestSpanFPreRecMetric:

def test_case1(self):
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans
@@ -136,33 +136,31 @@ class SpanFPreRecMetricTest(unittest.TestCase):
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
-0.3782, 0.8240],
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
-0.3782, 0.8240],
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
-0.3562, -1.4116],
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
2.0023, 0.7075],
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
[ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
2.0023, 0.7075],
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
0.3832, -0.1540],
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
-1.3508, -0.9513],
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
-0.0842, -0.4294]],
[
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
-1.4138, -0.8853],
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
-1.0726, 0.0364],
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
-0.8836, -0.9320],
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
-1.6857, 1.1571],
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
-0.5837, 1.0184],
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
-0.9025, 0.0864]]
]
])
[ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
-0.0842, -0.4294]],

[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
-1.4138, -0.8853],
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
-1.0726, 0.0364],
[ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
-0.8836, -0.9320],
[ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
-1.6857, 1.1571],
[ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
-0.5837, 1.0184],
[ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
-0.9025, 0.0864]]])
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]])
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6])
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5,
@@ -254,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase):
# print(expected_metric)
metric_value = metric.get_metric()
for key, value in expected_metric.items():
self.assertAlmostEqual(value, metric_value[key], places=5)
np.allclose(value, metric_value[key])

def test_auto_encoding_type_infer(self):
# check whether the encoding type can be inferred automatically
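
One thing to watch in the hunk above: unlike `assertAlmostEqual`, `np.allclose` merely returns a bool, so on its own it cannot fail the test; under pytest the comparison would normally be wrapped in an assert, e.g.:

    for key, value in expected_metric.items():
        assert np.allclose(value, metric_value[key])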
@@ -271,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase):
vocab.add_word('o')
vocabs[encoding_type] = vocab
for e in ['bio', 'bioes', 'bmeso']:
with self.subTest(e=e):
metric = SpanFPreRecMetric(tag_vocab=vocabs[e])
assert metric.encoding_type == e
metric = SpanFPreRecMetric(tag_vocab=vocabs[e])
assert metric.encoding_type == e

bmes_vocab = _generate_tags('bmes')
vocab = Vocabulary()
@@ -286,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase):
vocab = Vocabulary()
for i in range(10):
vocab.add_word(str(i))
with self.assertRaises(Exception):
with pytest.raises(Exception):
metric = SpanFPreRecMetric(vocab)

def test_encoding_type(self):
@@ -305,21 +302,20 @@ class SpanFPreRecMetricTest(unittest.TestCase):
vocab.add_word('o')
vocabs[encoding_type] = vocab
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']):
with self.subTest(e1=e1, e2=e2):
if e1 == e2:
if e1 == e2:
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
else:
s2 = set(e2)
s2.update(set(e1))
if s2 == set(e2):
continue
with pytest.raises(AssertionError):
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
else:
s2 = set(e2)
s2.update(set(e1))
if s2 == set(e2):
continue
with self.assertRaises(AssertionError):
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
for encoding_type in ['bio', 'bioes', 'bmeso']:
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes')

with self.assertWarns(Warning):
with pytest.warns(Warning):
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso')
vocab = Vocabulary().add_word_lst(list('bmes'))
