Browse Source

修复CRF为负数的bug

tags/v0.4.10
yh 5 years ago
parent
commit
e5f68bbd5b
5 changed files with 41 additions and 8 deletions
  1. +4
    -0
      fastNLP/core/dataset.py
  2. +5
    -5
      fastNLP/core/fieldarray.py
  3. +3
    -3
      fastNLP/modules/decoder/CRF.py
  4. +5
    -0
      test/core/test_dataset.py
  5. +24
    -0
      test/modules/decoder/test_CRF.py

+ 4
- 0
fastNLP/core/dataset.py View File

@@ -243,6 +243,8 @@ class DataSet(object):
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作.
:return:
"""
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_padder(padder)

def set_pad_val(self, field_name, pad_val):
@@ -253,6 +255,8 @@ class DataSet(object):
:param pad_val: int,该field的padder会以pad_val作为padding index
:return:
"""
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_pad_val(pad_val)

def get_input_name(self):


+ 5
- 5
fastNLP/core/fieldarray.py View File

@@ -206,7 +206,7 @@ class FieldArray(object):
if list in type_set:
if len(type_set) > 1:
# list 跟 非list 混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
# >1维list
inner_type_set = set()
for l in content:
@@ -229,7 +229,7 @@ class FieldArray(object):
return self._basic_type_detection(inner_inner_type_set)
else:
# list 跟 非list 混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set)))
else:
# 一维list
for content_type in type_set:
@@ -253,17 +253,17 @@ class FieldArray(object):
return float
else:
# str 跟 int 或者 float 混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
else:
# str, int, float混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))

def _1d_list_check(self, val):
"""如果不是1D list就报错
"""
type_set = set((type(obj) for obj in val))
if any(obj not in self.BASIC_TYPES for obj in type_set):
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
self._basic_type_detection(type_set)
# otherwise: _basic_type_detection will raise error
return True


+ 3
- 3
fastNLP/modules/decoder/CRF.py View File

@@ -192,7 +192,7 @@ class ConditionalRandomField(nn.Module):
seq_len, batch_size, n_tags = logits.size()
alpha = logits[0]
if self.include_start_end_trans:
alpha += self.start_scores.view(1, -1)
alpha = alpha + self.start_scores.view(1, -1)

flip_mask = mask.eq(0)

@@ -204,7 +204,7 @@ class ConditionalRandomField(nn.Module):
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

if self.include_start_end_trans:
alpha += self.end_scores.view(1, -1)
alpha = alpha + self.end_scores.view(1, -1)

return log_sum_exp(alpha, 1)

@@ -233,7 +233,7 @@ class ConditionalRandomField(nn.Module):
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = mask.long().sum(0) - 1
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
score += st_scores + ed_scores
score = score + st_scores + ed_scores
# return [B,]
return score



+ 5
- 0
test/core/test_dataset.py View File

@@ -216,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase):
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

def test_add_null(self):
ds = DataSet()
ds.add_field('test', [])
ds.set_target('test')


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):


+ 24
- 0
test/modules/decoder/test_CRF.py View File

@@ -101,4 +101,28 @@ class TestCRF(unittest.TestCase):
# # seq equal
# self.assertListEqual([_ for _, score in allen_res], fast_res[0])

def test_case3(self):
# 测试crf的loss不会出现负数
import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField
from fastNLP.core.utils import seq_lens_to_masks
from torch import optim
from torch import nn

num_tags, include_start_end_trans = 4, True
num_samples = 4
lengths = torch.randint(3, 50, size=(num_samples, )).long()
max_len = lengths.max()
tags = torch.randint(num_tags, size=(num_samples, max_len))
masks = seq_lens_to_masks(lengths)
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
crf = ConditionalRandomField(num_tags, include_start_end_trans)
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
for _ in range(10000):
loss = crf(feats, tags, masks).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if _%1000==0:
print(loss)
assert loss.item()>0, "CRF loss cannot be less than 0."

Loading…
Cancel
Save