@@ -243,6 +243,8 @@ class DataSet(object): | |||||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | ||||
:return: | :return: | ||||
""" | """ | ||||
if field_name not in self.field_arrays: | |||||
raise KeyError("There is no field named {}.".format(field_name)) | |||||
self.field_arrays[field_name].set_padder(padder) | self.field_arrays[field_name].set_padder(padder) | ||||
def set_pad_val(self, field_name, pad_val): | def set_pad_val(self, field_name, pad_val): | ||||
@@ -253,6 +255,8 @@ class DataSet(object): | |||||
:param pad_val: int,该field的padder会以pad_val作为padding index | :param pad_val: int,该field的padder会以pad_val作为padding index | ||||
:return: | :return: | ||||
""" | """ | ||||
if field_name not in self.field_arrays: | |||||
raise KeyError("There is no field named {}.".format(field_name)) | |||||
self.field_arrays[field_name].set_pad_val(pad_val) | self.field_arrays[field_name].set_pad_val(pad_val) | ||||
def get_input_name(self): | def get_input_name(self): | ||||
@@ -206,7 +206,7 @@ class FieldArray(object): | |||||
if list in type_set: | if list in type_set: | ||||
if len(type_set) > 1: | if len(type_set) > 1: | ||||
# list 跟 非list 混在一起 | # list 跟 非list 混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
# >1维list | # >1维list | ||||
inner_type_set = set() | inner_type_set = set() | ||||
for l in content: | for l in content: | ||||
@@ -229,7 +229,7 @@ class FieldArray(object): | |||||
return self._basic_type_detection(inner_inner_type_set) | return self._basic_type_detection(inner_inner_type_set) | ||||
else: | else: | ||||
# list 跟 非list 混在一起 | # list 跟 非list 混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) | |||||
else: | else: | ||||
# 一维list | # 一维list | ||||
for content_type in type_set: | for content_type in type_set: | ||||
@@ -253,17 +253,17 @@ class FieldArray(object): | |||||
return float | return float | ||||
else: | else: | ||||
# str 跟 int 或者 float 混在一起 | # str 跟 int 或者 float 混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
else: | else: | ||||
# str, int, float混在一起 | # str, int, float混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
def _1d_list_check(self, val):
    """Check that ``val`` is a 1-D list whose elements all have basic types.

    :param val: iterable whose element types are inspected.
    :return: True when every element type is in ``self.BASIC_TYPES`` and
        ``self._basic_type_detection`` accepts the combination of types.
    :raises ValueError: if any element type is not in ``self.BASIC_TYPES``.
    """
    type_set = {type(obj) for obj in val}
    if any(obj_type not in self.BASIC_TYPES for obj_type in type_set):
        # Report the offending types as a list rather than a raw set repr.
        raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
    # The return value is ignored here; per the original note, this call is
    # expected to raise when the basic types cannot be combined.
    self._basic_type_detection(type_set)
    return True
@@ -192,7 +192,7 @@ class ConditionalRandomField(nn.Module): | |||||
seq_len, batch_size, n_tags = logits.size() | seq_len, batch_size, n_tags = logits.size() | ||||
alpha = logits[0] | alpha = logits[0] | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha += self.start_scores.view(1, -1) | |||||
alpha = alpha + self.start_scores.view(1, -1) | |||||
flip_mask = mask.eq(0) | flip_mask = mask.eq(0) | ||||
@@ -204,7 +204,7 @@ class ConditionalRandomField(nn.Module): | |||||
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) | alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha += self.end_scores.view(1, -1) | |||||
alpha = alpha + self.end_scores.view(1, -1) | |||||
return log_sum_exp(alpha, 1) | return log_sum_exp(alpha, 1) | ||||
@@ -233,7 +233,7 @@ class ConditionalRandomField(nn.Module): | |||||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | ||||
last_idx = mask.long().sum(0) - 1 | last_idx = mask.long().sum(0) - 1 | ||||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | ||||
score += st_scores + ed_scores | |||||
score = score + st_scores + ed_scores | |||||
# return [B,] | # return [B,] | ||||
return score | return score | ||||
@@ -216,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(isinstance(ds, DataSet)) | self.assertTrue(isinstance(ds, DataSet)) | ||||
self.assertTrue(len(ds) > 0) | self.assertTrue(len(ds) > 0) | ||||
def test_add_null(self):
    """Smoke test: add a field with empty content, then mark it as a target.

    The test passes as long as neither call raises.
    """
    dataset = DataSet()
    dataset.add_field('test', [])
    dataset.set_target('test')
class TestDataSetIter(unittest.TestCase): | class TestDataSetIter(unittest.TestCase): | ||||
def test__repr__(self): | def test__repr__(self): | ||||
@@ -101,4 +101,28 @@ class TestCRF(unittest.TestCase): | |||||
# # seq equal | # # seq equal | ||||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | ||||
def test_case3(self):
    """Check that the CRF loss stays positive while it is being minimised.

    Randomly initialised, trainable features and the CRF's own parameters are
    optimised with SGD; the loss is checked after every optimisation step.
    """
    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField
    from fastNLP.core.utils import seq_lens_to_masks
    from torch import optim
    from torch import nn

    num_tags, include_start_end_trans = 4, True
    num_samples = 4
    lengths = torch.randint(3, 50, size=(num_samples, )).long()
    # Convert to a plain int so the size tuples below hold ints, not a 0-d tensor.
    max_len = lengths.max().item()
    tags = torch.randint(num_tags, size=(num_samples, max_len))
    masks = seq_lens_to_masks(lengths)
    feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
    crf = ConditionalRandomField(num_tags, include_start_end_trans)
    trainable = [param for param in crf.parameters() if param.requires_grad] + [feats]
    optimizer = optim.SGD(trainable, lr=0.1)
    for step in range(10000):
        loss = crf(feats, tags, masks).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            print(loss)
        # unittest assertion instead of a bare assert, which is stripped under -O.
        self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")