@@ -402,7 +402,7 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||||
if _TORCH_GREATER_EQUAL_1_8: | if _TORCH_GREATER_EQUAL_1_8: | ||||
objs = [None for _ in range(dist.get_world_size(group))] | objs = [None for _ in range(dist.get_world_size(group))] | ||||
dist.all_gather_object(objs, obj) | dist.all_gather_object(objs, obj) | ||||
apply_to_collection(obj, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||||
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||||
return objs | return objs | ||||
group = group if group is not None else torch.distributed.group.WORLD | group = group if group is not None else torch.distributed.group.WORLD | ||||
data = convert_to_tensors(obj, device=device) | data = convert_to_tensors(obj, device=device) | ||||
@@ -33,8 +33,7 @@ class TorchBackend(Backend): | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
if method is None: | if method is None: | ||||
raise AggregateMethodError(should_have_aggregate_method=True) | raise AggregateMethodError(should_have_aggregate_method=True) | ||||
tensor = self._gather_all(tensor) | |||||
# tensor = self.all_gather_object(tensor) | |||||
tensor = fastnlp_torch_all_gather(tensor) | |||||
if isinstance(tensor[0], torch.Tensor): | if isinstance(tensor[0], torch.Tensor): | ||||
tensor = torch.stack(tensor) | tensor = torch.stack(tensor) | ||||
# 第一步, aggregate结果 | # 第一步, aggregate结果 | ||||
@@ -69,59 +68,6 @@ class TorchBackend(Backend): | |||||
def get_scalar(self, tensor) -> float: | def get_scalar(self, tensor) -> float: | ||||
return tensor.item() | return tensor.item() | ||||
@staticmethod | |||||
def _gather_all(result, group: Optional[Any] = None) -> List: | |||||
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. | |||||
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case | |||||
tensors are padded, gathered and then trimmed to secure equal workload for all processes. | |||||
Args: | |||||
result: the value to sync | |||||
group: the process group to gather results from. Defaults to all processes (world) | |||||
Return: | |||||
gathered_result: list with size equal to the process group where | |||||
gathered_result[i] corresponds to result tensor from process i | |||||
""" | |||||
if group is None: | |||||
group = dist.group.WORLD | |||||
# convert tensors to contiguous format | |||||
result = result.contiguous() | |||||
world_size = dist.get_world_size(group) | |||||
dist.barrier(group=group) | |||||
# if the tensor is scalar, things are easy | |||||
if result.ndim == 0: | |||||
return _simple_gather_all_tensors(result, group, world_size) | |||||
# 1. Gather sizes of all tensors | |||||
local_size = torch.tensor(result.shape, device=result.device) | |||||
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] | |||||
dist.all_gather(local_sizes, local_size, group=group) | |||||
max_size = torch.stack(local_sizes).max(dim=0).values | |||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | |||||
# 2. If shapes are all the same, then do a simple gather: | |||||
if all_sizes_equal: | |||||
return _simple_gather_all_tensors(result, group, world_size) | |||||
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate | |||||
pad_dims = [] | |||||
pad_by = (max_size - local_size).detach().cpu() | |||||
for val in reversed(pad_by): | |||||
pad_dims.append(0) | |||||
pad_dims.append(val.item()) | |||||
result_padded = torch.nn.functional.pad(result, pad_dims) | |||||
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] | |||||
dist.all_gather(gathered_result, result_padded, group) | |||||
for idx, item_size in enumerate(local_sizes): | |||||
slice_param = [slice(dim_size) for dim_size in item_size] | |||||
gathered_result[idx] = gathered_result[idx][slice_param] | |||||
return gathered_result | |||||
def tensor2numpy(self, tensor) -> np.array: | def tensor2numpy(self, tensor) -> np.array: | ||||
""" | """ | ||||
将对应的tensor转为numpy对象 | 将对应的tensor转为numpy对象 | ||||
@@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
class Element: | class Element: | ||||
def __init__(self, value: float, aggregate_method, backend: Backend, name=None): | |||||
def __init__(self, name, value: float, aggregate_method, backend: Backend): | |||||
self.name = name | |||||
self.init_value = value | self.init_value = value | ||||
self.aggregate_method = aggregate_method | self.aggregate_method = aggregate_method | ||||
self.name = name | |||||
if backend == 'auto': | if backend == 'auto': | ||||
raise RuntimeError("You have to specify the backend.") | |||||
raise RuntimeError(f"You have to specify the backend for Element:{self.name}.") | |||||
elif isinstance(backend, AutoBackend): | elif isinstance(backend, AutoBackend): | ||||
self.backend = backend | self.backend = backend | ||||
else: | else: | ||||
@@ -41,14 +41,9 @@ class Element: | |||||
msg = 'If you see this message, please report a bug.' | msg = 'If you see this message, please report a bug.' | ||||
if self.name and e.should_have_aggregate_method: | if self.name and e.should_have_aggregate_method: | ||||
msg = f"Element:{self.name} has no specified `aggregate_method`." | msg = f"Element:{self.name} has no specified `aggregate_method`." | ||||
elif e.should_have_aggregate_method: | |||||
msg = "Element has no specified `aggregate_method`." | |||||
elif self.name and not e.should_have_aggregate_method: | elif self.name and not e.should_have_aggregate_method: | ||||
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ | msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ | ||||
f'aggregate_method:{self.aggregate_method}.' | f'aggregate_method:{self.aggregate_method}.' | ||||
elif not e.should_have_aggregate_method: | |||||
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ | |||||
f'aggregate_method:{self.aggregate_method}.' | |||||
if e.only_warn: | if e.only_warn: | ||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
logger.warning(msg) | logger.warning(msg) | ||||
@@ -75,6 +70,7 @@ class Element: | |||||
return self._value | return self._value | ||||
def get_scalar(self) -> float: | def get_scalar(self) -> float: | ||||
self._check_value_initialized() | |||||
return self.backend.get_scalar(self._value) | return self.backend.get_scalar(self._value) | ||||
def fill_value(self, value): | def fill_value(self, value): | ||||
@@ -96,7 +92,7 @@ class Element: | |||||
def _check_value_when_call(self): | def _check_value_when_call(self): | ||||
if self.value is None: | if self.value is None: | ||||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||||
prefix = f'Element:`{self.name}`' | |||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | ||||
"element, or use it after it being used by the `Metric.compute()` method.") | "element, or use it after it being used by the `Metric.compute()` method.") | ||||
@@ -274,9 +270,10 @@ class Element: | |||||
""" | """ | ||||
try: | try: | ||||
if self._value is None: | if self._value is None: | ||||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||||
prefix = f'Element:`{self.name}`' | |||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | ||||
"element, or use it after it being used by the `Metric.compute()` method.") | "element, or use it after it being used by the `Metric.compute()` method.") | ||||
return getattr(self._value, item) | return getattr(self._value, item) | ||||
except AttributeError as e: | except AttributeError as e: | ||||
logger.error(f"Element:{self.name} has no `{item}` attribute.") | |||||
raise e | raise e |
@@ -35,7 +35,7 @@ class Metric: | |||||
def elements(self) -> dict: | def elements(self) -> dict: | ||||
return self._elements | return self._elements | ||||
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||||
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||||
""" | """ | ||||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | ||||
tensor 直接进行加减乘除计算即可。 | tensor 直接进行加减乘除计算即可。 | ||||
@@ -57,11 +57,9 @@ class Metric: | |||||
else: | else: | ||||
backend = AutoBackend(backend) | backend = AutoBackend(backend) | ||||
# 当name为None,默认为变量取得变量名 | |||||
if name is None: | |||||
name = f'ele_var_{len(self._elements)}' | |||||
assert name is not None and name not in self.elements | |||||
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) | |||||
element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend) | |||||
self.elements[name] = element | self.elements[name] = element | ||||
setattr(self, name, element) | setattr(self, name, element) | ||||
return element | return element | ||||
@@ -219,6 +219,23 @@ class SpanFPreRecMetric(Metric): | |||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | ||||
only_gross: bool = True, f_type='micro', | only_gross: bool = True, f_type='micro', | ||||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | ||||
r""" | |||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。 | |||||
""" | |||||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | ||||
@@ -255,7 +272,7 @@ class SpanFPreRecMetric(Metric): | |||||
for word, _ in tag_vocab: | for word, _ in tag_vocab: | ||||
word = word.lower() | word = word.lower() | ||||
if word != 'o': | if word != 'o': | ||||
word = word.split('-')[1] | |||||
word = word[2:] | |||||
if word in self._true_positives: | if word in self._true_positives: | ||||
continue | continue | ||||
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | ||||
@@ -266,8 +283,8 @@ class SpanFPreRecMetric(Metric): | |||||
evaluate_result = {} | evaluate_result = {} | ||||
if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
tags = set(self._false_negatives.keys()) | tags = set(self._false_negatives.keys()) | ||||
tags.update(set(self._false_positives.keys())) | |||||
tags.update(set(self._true_positives.keys())) | |||||
tags.update(self._false_positives.keys()) | |||||
tags.update(self._true_positives.keys()) | |||||
f_sum = 0 | f_sum = 0 | ||||
pre_sum = 0 | pre_sum = 0 | ||||
rec_sum = 0 | rec_sum = 0 | ||||
@@ -275,6 +292,9 @@ class SpanFPreRecMetric(Metric): | |||||
tp = self._true_positives[tag].get_scalar() | tp = self._true_positives[tag].get_scalar() | ||||
fn = self._false_negatives[tag].get_scalar() | fn = self._false_negatives[tag].get_scalar() | ||||
fp = self._false_positives[tag].get_scalar() | fp = self._false_positives[tag].get_scalar() | ||||
if tp == fn == fp == 0: | |||||
continue | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | ||||
f_sum += f | f_sum += f | ||||
pre_sum += pre | pre_sum += pre | ||||
@@ -150,7 +150,7 @@ def seed_jittor_global_seed(global_seed): | |||||
pass | pass | ||||
def dump_fastnlp_backend(default:bool = False): | |||||
def dump_fastnlp_backend(default:bool = False, backend=None): | |||||
""" | """ | ||||
将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, | 将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, | ||||
若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 | 若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 | ||||
@@ -162,6 +162,7 @@ def dump_fastnlp_backend(default:bool = False): | |||||
会保存的环境变量为 FASTNLP_BACKEND 。 | 会保存的环境变量为 FASTNLP_BACKEND 。 | ||||
:param default: | :param default: | ||||
:param backend: 保存使用的 backend 为哪个值,允许的值有 ['torch', 'paddle', 'jittor']。如果为 None ,则使用环境变量中的值。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
@@ -176,10 +177,16 @@ def dump_fastnlp_backend(default:bool = False): | |||||
os.makedirs(os.path.dirname(env_path), exist_ok=True) | os.makedirs(os.path.dirname(env_path), exist_ok=True) | ||||
envs = {} | envs = {} | ||||
if FASTNLP_BACKEND in os.environ: | |||||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||||
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." | |||||
if backend is None: | |||||
if FASTNLP_BACKEND in os.environ: | |||||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||||
else: | |||||
envs[FASTNLP_BACKEND] = backend | |||||
if len(envs): | if len(envs): | ||||
with open(env_path, 'w', encoding='utf8') as f: | with open(env_path, 'w', encoding='utf8') as f: | ||||
json.dump(fp=f, obj=envs) | json.dump(fp=f, obj=envs) | ||||
print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") | print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") | ||||
else: | |||||
raise RuntimeError("No backend specified.") |
@@ -48,7 +48,8 @@ def set_env_on_import_paddle(): | |||||
# TODO jittor may need set this | # TODO jittor may need set this | ||||
def set_env_on_import_jittor(): | def set_env_on_import_jittor(): | ||||
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | ||||
pass | |||||
if 'log_silent' not in os.environ: | |||||
os.environ['log_silent'] = '1' | |||||
def set_env_on_import(): | def set_env_on_import(): | ||||
@@ -99,7 +99,7 @@ def _test(local_rank: int, | |||||
assert my_result == sklearn_metric | assert my_result == sklearn_metric | ||||
class SpanFPreRecMetricTest(unittest.TestCase): | |||||
class TestSpanFPreRecMetric: | |||||
def test_case1(self): | def test_case1(self): | ||||
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | ||||
@@ -136,33 +136,31 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | ||||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | ||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | ||||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||||
-0.3782, 0.8240], | |||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||||
-0.3782, 0.8240], | |||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||||
-0.3562, -1.4116], | -0.3562, -1.4116], | ||||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||||
2.0023, 0.7075], | |||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||||
[ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||||
2.0023, 0.7075], | |||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||||
0.3832, -0.1540], | 0.3832, -0.1540], | ||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||||
-1.3508, -0.9513], | -1.3508, -0.9513], | ||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||||
-0.0842, -0.4294]], | |||||
[ | |||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
-0.5837, 1.0184], | |||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]] | |||||
] | |||||
]) | |||||
[ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||||
-0.0842, -0.4294]], | |||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
-0.5837, 1.0184], | |||||
[ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]]]) | |||||
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) | bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) | ||||
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | ||||
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | ||||
@@ -254,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
# print(expected_metric) | # print(expected_metric) | ||||
metric_value = metric.get_metric() | metric_value = metric.get_metric() | ||||
for key, value in expected_metric.items(): | for key, value in expected_metric.items(): | ||||
self.assertAlmostEqual(value, metric_value[key], places=5) | |||||
np.allclose(value, metric_value[key]) | |||||
def test_auto_encoding_type_infer(self): | def test_auto_encoding_type_infer(self): | ||||
# 检查是否可以自动check encode的类型 | # 检查是否可以自动check encode的类型 | ||||
@@ -271,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
vocab.add_word('o') | vocab.add_word('o') | ||||
vocabs[encoding_type] = vocab | vocabs[encoding_type] = vocab | ||||
for e in ['bio', 'bioes', 'bmeso']: | for e in ['bio', 'bioes', 'bmeso']: | ||||
with self.subTest(e=e): | |||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||||
assert metric.encoding_type == e | |||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||||
assert metric.encoding_type == e | |||||
bmes_vocab = _generate_tags('bmes') | bmes_vocab = _generate_tags('bmes') | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
@@ -286,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
for i in range(10): | for i in range(10): | ||||
vocab.add_word(str(i)) | vocab.add_word(str(i)) | ||||
with self.assertRaises(Exception): | |||||
with pytest.raises(Exception): | |||||
metric = SpanFPreRecMetric(vocab) | metric = SpanFPreRecMetric(vocab) | ||||
def test_encoding_type(self): | def test_encoding_type(self): | ||||
@@ -305,21 +302,20 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
vocab.add_word('o') | vocab.add_word('o') | ||||
vocabs[encoding_type] = vocab | vocabs[encoding_type] = vocab | ||||
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): | for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): | ||||
with self.subTest(e1=e1, e2=e2): | |||||
if e1 == e2: | |||||
if e1 == e2: | |||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||||
else: | |||||
s2 = set(e2) | |||||
s2.update(set(e1)) | |||||
if s2 == set(e2): | |||||
continue | |||||
with pytest.raises(AssertionError): | |||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | ||||
else: | |||||
s2 = set(e2) | |||||
s2.update(set(e1)) | |||||
if s2 == set(e2): | |||||
continue | |||||
with self.assertRaises(AssertionError): | |||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||||
for encoding_type in ['bio', 'bioes', 'bmeso']: | for encoding_type in ['bio', 'bioes', 'bmeso']: | ||||
with self.assertRaises(AssertionError): | |||||
with pytest.raises(AssertionError): | |||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') | metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') | ||||
with self.assertWarns(Warning): | |||||
with pytest.warns(Warning): | |||||
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | ||||
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | ||||
vocab = Vocabulary().add_word_lst(list('bmes')) | vocab = Vocabulary().add_word_lst(list('bmes')) |