| @@ -226,7 +226,7 @@ class TestSpanFPreRecMetric: | |||||
| # print(expected_metric) | # print(expected_metric) | ||||
| metric_value = metric.get_metric() | metric_value = metric.get_metric() | ||||
| for key, value in expected_metric.items(): | for key, value in expected_metric.items(): | ||||
| assert np.allclose(value, metric_value[key]) | |||||
| assert np.allclose(value, metric_value[key], 1.e-4) | |||||
| def test_auto_encoding_type_infer(self): | def test_auto_encoding_type_infer(self): | ||||
| # 检查是否可以自动check encode的类型 | # 检查是否可以自动check encode的类型 | ||||
| @@ -289,11 +289,10 @@ class TestSpanFPreRecMetric: | |||||
| with pytest.raises(AssertionError): | with pytest.raises(AssertionError): | ||||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') | metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') | ||||
| with pytest.warns(Warning): | |||||
| vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | |||||
| metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||||
| vocab = Vocabulary().add_word_lst(list('bmes')) | |||||
| metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||||
| vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | |||||
| metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||||
| vocab = Vocabulary().add_word_lst(list('bmes')) | |||||
| metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||||
| def test_case5(self): | def test_case5(self): | ||||
| # global pool | # global pool | ||||