Browse Source

- fix test

tags/v0.4.10
yunfan 5 years ago
parent
commit
58f373d371
2 changed files with 8 additions and 10 deletions
  1. +2
    -4
      fastNLP/core/callback.py
  2. +6
    -6
      test/core/test_metrics.py

+ 2
- 4
fastNLP/core/callback.py View File

@@ -249,13 +249,11 @@ class GradientClipCallback(Callback):
self.parameters = parameters
self.clip_value = clip_value

def on_backward_end(self, model):
def on_backward_end(self):
if self.parameters is None:
self.clip_fun(model.parameters(), self.clip_value)
self.clip_fun(self.model.parameters(), self.clip_value)
else:
self.clip_fun(self.parameters, self.clip_value)
def on_backward_end(self):
self.clip_fun(self.model.parameters(), self.clip_value)


class CallbackException(BaseException):


+ 6
- 6
test/core/test_metrics.py View File

@@ -141,11 +141,11 @@ class SpanF1PreRecMetric(unittest.TestCase):
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8']
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8']
expect_bmes_res = set()
expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)),
('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))])
expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)),
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))])
expect_bio_res = set()
expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)),
('6', (4, 4)), ('7', (6, 6))])
expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)),
('6', (4, 5)), ('7', (6, 7))])
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
@@ -168,9 +168,9 @@ class SpanF1PreRecMetric(unittest.TestCase):
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E']
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O']
expect_bmes_res = set()
expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))])
expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))])
expect_bio_res = set()
expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))])
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))])
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试


Loading…
Cancel
Save