|
@@ -226,7 +226,7 @@ class TestSpanFPreRecMetric: |
|
|
# print(expected_metric) |
|
|
# print(expected_metric) |
|
|
metric_value = metric.get_metric() |
|
|
metric_value = metric.get_metric() |
|
|
for key, value in expected_metric.items(): |
|
|
for key, value in expected_metric.items(): |
|
|
assert np.allclose(value, metric_value[key]) |
|
|
|
|
|
|
|
|
assert np.allclose(value, metric_value[key], 1.e-4) |
|
|
|
|
|
|
|
|
def test_auto_encoding_type_infer(self): |
|
|
def test_auto_encoding_type_infer(self): |
|
|
# 检查是否可以自动check encode的类型 |
|
|
# 检查是否可以自动check encode的类型 |
|
@@ -289,11 +289,10 @@ class TestSpanFPreRecMetric: |
|
|
with pytest.raises(AssertionError): |
|
|
with pytest.raises(AssertionError): |
|
|
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') |
|
|
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') |
|
|
|
|
|
|
|
|
with pytest.warns(Warning): |
|
|
|
|
|
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) |
|
|
|
|
|
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') |
|
|
|
|
|
vocab = Vocabulary().add_word_lst(list('bmes')) |
|
|
|
|
|
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') |
|
|
|
|
|
|
|
|
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) |
|
|
|
|
|
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') |
|
|
|
|
|
vocab = Vocabulary().add_word_lst(list('bmes')) |
|
|
|
|
|
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') |
|
|
|
|
|
|
|
|
def test_case5(self): |
|
|
def test_case5(self): |
|
|
# global pool |
|
|
# global pool |
|
|