@@ -402,7 +402,7 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
objs = [None for _ in range(dist.get_world_size(group))] | |||
dist.all_gather_object(objs, obj) | |||
apply_to_collection(obj, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
return objs | |||
group = group if group is not None else torch.distributed.group.WORLD | |||
data = convert_to_tensors(obj, device=device) | |||
@@ -33,8 +33,7 @@ class TorchBackend(Backend): | |||
if dist.is_initialized(): | |||
if method is None: | |||
raise AggregateMethodError(should_have_aggregate_method=True) | |||
tensor = self._gather_all(tensor) | |||
# tensor = self.all_gather_object(tensor) | |||
tensor = fastnlp_torch_all_gather(tensor) | |||
if isinstance(tensor[0], torch.Tensor): | |||
tensor = torch.stack(tensor) | |||
# 第一步, aggregate结果 | |||
@@ -69,59 +68,6 @@ class TorchBackend(Backend): | |||
def get_scalar(self, tensor) -> float:
    """Return the value obtained by calling ``tensor.item()``.

    :param tensor: any object exposing an ``item()`` method.
    :return: whatever ``tensor.item()`` returns.
    """
    scalar = tensor.item()
    return scalar
@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
    """Gather a tensor from every process of ``group`` into a list.

    Works on tensors that have the same number of dimensions but whose
    per-dimension sizes may differ between processes. When the shapes differ,
    each local tensor is padded up to the element-wise maximum shape,
    gathered, and then every gathered tensor is trimmed back to the shape
    reported by its source process.

    Args:
        result: the local tensor to sync.
        group: the process group to gather results from. Defaults to
            ``dist.group.WORLD`` when ``None``.

    Return:
        gathered_result: list with one entry per process in the group, where
            ``gathered_result[i]`` corresponds to the tensor from process ``i``.
    """
    if group is None:
        group = dist.group.WORLD

    # Convert the tensor to contiguous format before handing it to the collectives.
    result = result.contiguous()

    world_size = dist.get_world_size(group)
    dist.barrier(group=group)

    # A scalar has no shape to reconcile, so a plain gather is enough.
    if result.ndim == 0:
        return _simple_gather_all_tensors(result, group, world_size)

    # 1. Gather the shape of the tensor held by every process.
    local_size = torch.tensor(result.shape, device=result.device)
    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(local_sizes, local_size, group=group)
    max_size = torch.stack(local_sizes).max(dim=0).values
    all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

    # 2. Identical shapes everywhere: a plain gather is enough.
    if all_sizes_equal:
        return _simple_gather_all_tensors(result, group, world_size)

    # 3. Otherwise pad the local tensor to the maximum size, gather, then trim.
    #    ``torch.nn.functional.pad`` takes (before, after) pairs starting from
    #    the LAST dimension, hence the reversed iteration; only the "after"
    #    side is padded so the original data stays at the origin.
    pad_dims = []
    pad_by = (max_size - local_size).detach().cpu()
    for val in reversed(pad_by.tolist()):
        pad_dims.append(0)
        pad_dims.append(val)
    result_padded = torch.nn.functional.pad(result, pad_dims)

    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
    dist.all_gather(gathered_result, result_padded, group=group)

    for idx, item_size in enumerate(local_sizes):
        # Index with a *tuple* of slices: indexing with a list of slices relies
        # on deprecated non-tuple sequence indexing.
        slice_param = tuple(slice(dim_size) for dim_size in item_size.tolist())
        gathered_result[idx] = gathered_result[idx][slice_param]
    return gathered_result
def tensor2numpy(self, tensor) -> np.array: | |||
""" | |||
将对应的tensor转为numpy对象 | |||
@@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
class Element: | |||
def __init__(self, value: float, aggregate_method, backend: Backend, name=None): | |||
def __init__(self, name, value: float, aggregate_method, backend: Backend): | |||
self.name = name | |||
self.init_value = value | |||
self.aggregate_method = aggregate_method | |||
self.name = name | |||
if backend == 'auto': | |||
raise RuntimeError("You have to specify the backend.") | |||
raise RuntimeError(f"You have to specify the backend for Element:{self.name}.") | |||
elif isinstance(backend, AutoBackend): | |||
self.backend = backend | |||
else: | |||
@@ -41,14 +41,9 @@ class Element: | |||
msg = 'If you see this message, please report a bug.' | |||
if self.name and e.should_have_aggregate_method: | |||
msg = f"Element:{self.name} has no specified `aggregate_method`." | |||
elif e.should_have_aggregate_method: | |||
msg = "Element has no specified `aggregate_method`." | |||
elif self.name and not e.should_have_aggregate_method: | |||
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ | |||
f'aggregate_method:{self.aggregate_method}.' | |||
elif not e.should_have_aggregate_method: | |||
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ | |||
f'aggregate_method:{self.aggregate_method}.' | |||
if e.only_warn: | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
logger.warning(msg) | |||
@@ -75,6 +70,7 @@ class Element: | |||
return self._value | |||
def get_scalar(self) -> float:
    """Return this element's value as reported by the backend's ``get_scalar``.

    ``self._check_value_initialized()`` is invoked first; afterwards
    ``self._value`` is passed to ``self.backend.get_scalar``.

    :return: the result of ``self.backend.get_scalar(self._value)``.
    """
    self._check_value_initialized()
    backend = self.backend
    return backend.get_scalar(self._value)
def fill_value(self, value): | |||
@@ -96,7 +92,7 @@ class Element: | |||
def _check_value_when_call(self): | |||
if self.value is None: | |||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||
prefix = f'Element:`{self.name}`' | |||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||
"element, or use it after it being used by the `Metric.compute()` method.") | |||
@@ -274,9 +270,10 @@ class Element: | |||
""" | |||
try: | |||
if self._value is None: | |||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||
prefix = f'Element:`{self.name}`' | |||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||
"element, or use it after it being used by the `Metric.compute()` method.") | |||
return getattr(self._value, item) | |||
except AttributeError as e: | |||
logger.error(f"Element:{self.name} has no `{item}` attribute.") | |||
raise e |
@@ -35,7 +35,7 @@ class Metric: | |||
def elements(self) -> dict:
    """Return the mapping stored in ``self._elements``."""
    registered = self._elements
    return registered
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||
""" | |||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | |||
tensor 直接进行加减乘除计算即可。 | |||
@@ -57,11 +57,9 @@ class Metric: | |||
else: | |||
backend = AutoBackend(backend) | |||
# 当name为None,默认为变量取得变量名 | |||
if name is None: | |||
name = f'ele_var_{len(self._elements)}' | |||
assert name is not None and name not in self.elements | |||
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) | |||
element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend) | |||
self.elements[name] = element | |||
setattr(self, name, element) | |||
return element | |||
@@ -219,6 +219,23 @@ class SpanFPreRecMetric(Metric): | |||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||
only_gross: bool = True, f_type='micro', | |||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | |||
r""" | |||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
当 backend 不支持分布式时,该参数无意义。 | |||
""" | |||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
@@ -255,7 +272,7 @@ class SpanFPreRecMetric(Metric): | |||
for word, _ in tag_vocab: | |||
word = word.lower() | |||
if word != 'o': | |||
word = word.split('-')[1] | |||
word = word[2:] | |||
if word in self._true_positives: | |||
continue | |||
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||
@@ -266,8 +283,8 @@ class SpanFPreRecMetric(Metric): | |||
evaluate_result = {} | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._false_negatives.keys()) | |||
tags.update(set(self._false_positives.keys())) | |||
tags.update(set(self._true_positives.keys())) | |||
tags.update(self._false_positives.keys()) | |||
tags.update(self._true_positives.keys()) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
@@ -275,6 +292,9 @@ class SpanFPreRecMetric(Metric): | |||
tp = self._true_positives[tag].get_scalar() | |||
fn = self._false_negatives[tag].get_scalar() | |||
fp = self._false_positives[tag].get_scalar() | |||
if tp == fn == fp == 0: | |||
continue | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
@@ -150,7 +150,7 @@ def seed_jittor_global_seed(global_seed): | |||
pass | |||
def dump_fastnlp_backend(default:bool = False): | |||
def dump_fastnlp_backend(default:bool = False, backend=None): | |||
""" | |||
将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, | |||
若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 | |||
@@ -162,6 +162,7 @@ def dump_fastnlp_backend(default:bool = False): | |||
会保存的环境变量为 FASTNLP_BACKEND 。 | |||
:param default: | |||
:param backend: 保存使用的 backend 为哪个值,允许的值有 ['torch', 'paddle', 'jittor']。如果为 None ,则使用环境变量中的值。 | |||
:return: | |||
""" | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
@@ -176,10 +177,16 @@ def dump_fastnlp_backend(default:bool = False): | |||
os.makedirs(os.path.dirname(env_path), exist_ok=True) | |||
envs = {} | |||
if FASTNLP_BACKEND in os.environ: | |||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." | |||
if backend is None: | |||
if FASTNLP_BACKEND in os.environ: | |||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||
else: | |||
envs[FASTNLP_BACKEND] = backend | |||
if len(envs): | |||
with open(env_path, 'w', encoding='utf8') as f: | |||
json.dump(fp=f, obj=envs) | |||
print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") | |||
else: | |||
raise RuntimeError("No backend specified.") |
@@ -48,7 +48,8 @@ def set_env_on_import_paddle(): | |||
# TODO jittor may need set this | |||
def set_env_on_import_jittor():
    """Set the ``log_silent`` environment variable to ``'1'`` unless it is already set."""
    # TODO: FASTNLP_GLOBAL_RANK and FASTNLP_BACKEND_LAUNCH still need to be set here.
    os.environ.setdefault('log_silent', '1')
def set_env_on_import(): | |||
@@ -99,7 +99,7 @@ def _test(local_rank: int, | |||
assert my_result == sklearn_metric | |||
class SpanFPreRecMetricTest(unittest.TestCase): | |||
class TestSpanFPreRecMetric: | |||
def test_case1(self): | |||
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | |||
@@ -136,33 +136,31 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | |||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
-0.3782, 0.8240], | |||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
-0.3782, 0.8240], | |||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
-0.3562, -1.4116], | |||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
2.0023, 0.7075], | |||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
[ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
2.0023, 0.7075], | |||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
0.3832, -0.1540], | |||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
-1.3508, -0.9513], | |||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
-0.0842, -0.4294]], | |||
[ | |||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
-1.4138, -0.8853], | |||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
-1.0726, 0.0364], | |||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
-0.8836, -0.9320], | |||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
-1.6857, 1.1571], | |||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
-0.5837, 1.0184], | |||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
-0.9025, 0.0864]] | |||
] | |||
]) | |||
[ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
-0.0842, -0.4294]], | |||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
-1.4138, -0.8853], | |||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
-1.0726, 0.0364], | |||
[ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
-0.8836, -0.9320], | |||
[ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
-1.6857, 1.1571], | |||
[ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
-0.5837, 1.0184], | |||
[ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
-0.9025, 0.0864]]]) | |||
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) | |||
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | |||
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | |||
@@ -254,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
# print(expected_metric) | |||
metric_value = metric.get_metric() | |||
for key, value in expected_metric.items(): | |||
self.assertAlmostEqual(value, metric_value[key], places=5) | |||
np.allclose(value, metric_value[key]) | |||
def test_auto_encoding_type_infer(self): | |||
# 检查是否可以自动check encode的类型 | |||
@@ -271,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
vocab.add_word('o') | |||
vocabs[encoding_type] = vocab | |||
for e in ['bio', 'bioes', 'bmeso']: | |||
with self.subTest(e=e): | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||
assert metric.encoding_type == e | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||
assert metric.encoding_type == e | |||
bmes_vocab = _generate_tags('bmes') | |||
vocab = Vocabulary() | |||
@@ -286,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
vocab = Vocabulary() | |||
for i in range(10): | |||
vocab.add_word(str(i)) | |||
with self.assertRaises(Exception): | |||
with pytest.raises(Exception): | |||
metric = SpanFPreRecMetric(vocab) | |||
def test_encoding_type(self): | |||
@@ -305,21 +302,20 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
vocab.add_word('o') | |||
vocabs[encoding_type] = vocab | |||
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): | |||
with self.subTest(e1=e1, e2=e2): | |||
if e1 == e2: | |||
if e1 == e2: | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||
else: | |||
s2 = set(e2) | |||
s2.update(set(e1)) | |||
if s2 == set(e2): | |||
continue | |||
with pytest.raises(AssertionError): | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||
else: | |||
s2 = set(e2) | |||
s2.update(set(e1)) | |||
if s2 == set(e2): | |||
continue | |||
with self.assertRaises(AssertionError): | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||
for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
with self.assertRaises(AssertionError): | |||
with pytest.raises(AssertionError): | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') | |||
with self.assertWarns(Warning): | |||
with pytest.warns(Warning): | |||
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | |||
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||
vocab = Vocabulary().add_word_lst(list('bmes')) |