diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 4a4e29a5..1bda1f93 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -249,13 +249,11 @@ class GradientClipCallback(Callback): self.parameters = parameters self.clip_value = clip_value - def on_backward_end(self, model): + def on_backward_end(self): if self.parameters is None: - self.clip_fun(model.parameters(), self.clip_value) + self.clip_fun(self.model.parameters(), self.clip_value) else: self.clip_fun(self.parameters, self.clip_value) - def on_backward_end(self): - self.clip_fun(self.model.parameters(), self.clip_value) class CallbackException(BaseException): diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 80ed54e2..25138478 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -141,11 +141,11 @@ class SpanF1PreRecMetric(unittest.TestCase): bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] expect_bmes_res = set() - expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)), - ('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))]) + expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)), + ('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) expect_bio_res = set() - expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), - ('6', (4, 4)), ('7', (6, 6))]) + expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), + ('6', (4, 5)), ('7', (6, 7))]) self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 @@ -168,9 +168,9 @@ class SpanF1PreRecMetric(unittest.TestCase): bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] expect_bmes_res = set() - expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))]) + expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) expect_bio_res = set() - expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))]) + expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试