From e5f68bbd5b564e5dfa2d6003481a4fec9254e0dc Mon Sep 17 00:00:00 2001 From: yh Date: Sat, 23 Mar 2019 18:12:32 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DCRF=E4=B8=BA=E8=B4=9F?= =?UTF-8?q?=E6=95=B0=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 4 ++++ fastNLP/core/fieldarray.py | 10 +++++----- fastNLP/modules/decoder/CRF.py | 6 +++--- test/core/test_dataset.py | 5 +++++ test/modules/decoder/test_CRF.py | 24 ++++++++++++++++++++++++ 5 files changed, 41 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 4b995c94..24376a72 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -243,6 +243,8 @@ class DataSet(object): :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. :return: """ + if field_name not in self.field_arrays: + raise KeyError("There is no field named {}.".format(field_name)) self.field_arrays[field_name].set_padder(padder) def set_pad_val(self, field_name, pad_val): @@ -253,6 +255,8 @@ class DataSet(object): :param pad_val: int,该field的padder会以pad_val作为padding index :return: """ + if field_name not in self.field_arrays: + raise KeyError("There is no field named {}.".format(field_name)) self.field_arrays[field_name].set_pad_val(pad_val) def get_input_name(self): diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 8e42f500..72bb30b5 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -206,7 +206,7 @@ class FieldArray(object): if list in type_set: if len(type_set) > 1: # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) # >1维list inner_type_set = set() for l in content: @@ -229,7 +229,7 @@ class FieldArray(object): return self._basic_type_detection(inner_inner_type_set) else: # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) else: # 一维list for content_type in type_set: @@ -253,17 +253,17 @@ class FieldArray(object): return float else: # str 跟 int 或者 float 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) else: # str, int, float混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) def _1d_list_check(self, val): """如果不是1D list就报错 """ type_set = set((type(obj) for obj in val)) if any(obj not in self.BASIC_TYPES for obj in type_set): - raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) self._basic_type_detection(type_set) # otherwise: _basic_type_detection will raise error return True diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 46350945..df004224 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -192,7 +192,7 @@ class ConditionalRandomField(nn.Module): seq_len, batch_size, n_tags = logits.size() alpha = logits[0] if self.include_start_end_trans: - alpha += self.start_scores.view(1, -1) + alpha = alpha + self.start_scores.view(1, -1) flip_mask = mask.eq(0) @@ -204,7 +204,7 @@ class ConditionalRandomField(nn.Module): alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) if self.include_start_end_trans: - alpha += self.end_scores.view(1, -1) + alpha = alpha + self.end_scores.view(1, -1) return log_sum_exp(alpha, 1) @@ -233,7 +233,7 @@ class ConditionalRandomField(nn.Module): st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] last_idx = mask.long().sum(0) - 1 ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] - score += st_scores + ed_scores + score = score + st_scores + ed_scores # return [B,] return score diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index eb4c97e8..607f9a13 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -216,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase): self.assertTrue(isinstance(ds, DataSet)) self.assertTrue(len(ds) > 0) + def test_add_null(self): + ds = DataSet() + ds.add_field('test', []) + ds.set_target('test') + class TestDataSetIter(unittest.TestCase): def test__repr__(self): diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py index 4576d518..a176348f 100644 --- a/test/modules/decoder/test_CRF.py +++ b/test/modules/decoder/test_CRF.py @@ -101,4 +101,28 @@ class TestCRF(unittest.TestCase): # # seq equal # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) + def test_case3(self): + # 测试crf的loss不会出现负数 + import torch + from fastNLP.modules.decoder.CRF import ConditionalRandomField + from fastNLP.core.utils import seq_lens_to_masks + from torch import optim + from torch import nn + num_tags, include_start_end_trans = 4, True + num_samples = 4 + lengths = torch.randint(3, 50, size=(num_samples, )).long() + max_len = lengths.max() + tags = torch.randint(num_tags, size=(num_samples, max_len)) + masks = seq_lens_to_masks(lengths) + feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) + crf = ConditionalRandomField(num_tags, include_start_end_trans) + optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) + for _ in range(10000): + loss = crf(feats, tags, masks).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + if _%1000==0: + print(loss) + assert loss.item()>0, "CRF loss cannot be less than 0."