diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py index 227c9643..b053fc66 100644 --- a/tests/core/metrics/test_span_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -226,7 +226,7 @@ class TestSpanFPreRecMetric: # print(expected_metric) metric_value = metric.get_metric() for key, value in expected_metric.items(): - assert np.allclose(value, metric_value[key]) + assert np.allclose(value, metric_value[key], 1.e-4) def test_auto_encoding_type_infer(self): # 检查是否可以自动check encode的类型 @@ -289,11 +289,10 @@ class TestSpanFPreRecMetric: with pytest.raises(AssertionError): metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') - with pytest.warns(Warning): - vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) - metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') - vocab = Vocabulary().add_word_lst(list('bmes')) - metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') + vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) + metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') + vocab = Vocabulary().add_word_lst(list('bmes')) + metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') def test_case5(self): # global pool