@@ -243,6 +243,8 @@ class DataSet(object): | |||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | |||
:return: | |||
""" | |||
if field_name not in self.field_arrays: | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_padder(padder) | |||
def set_pad_val(self, field_name, pad_val): | |||
@@ -253,6 +255,8 @@ class DataSet(object): | |||
:param pad_val: int,该field的padder会以pad_val作为padding index | |||
:return: | |||
""" | |||
if field_name not in self.field_arrays: | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_pad_val(pad_val) | |||
def get_input_name(self): | |||
@@ -206,7 +206,7 @@ class FieldArray(object): | |||
if list in type_set: | |||
if len(type_set) > 1: | |||
# list 跟 非list 混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
# >1维list | |||
inner_type_set = set() | |||
for l in content: | |||
@@ -229,7 +229,7 @@ class FieldArray(object): | |||
return self._basic_type_detection(inner_inner_type_set) | |||
else: | |||
# list 跟 非list 混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) | |||
else: | |||
# 一维list | |||
for content_type in type_set: | |||
@@ -253,17 +253,17 @@ class FieldArray(object): | |||
return float | |||
else: | |||
# str 跟 int 或者 float 混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
else: | |||
# str, int, float混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
def _1d_list_check(self, val):
    """Validate that ``val`` is a flat (1-D) list of basic-typed elements.

    :param val: an iterable of scalar elements; each element's type must be
        one of ``self.BASIC_TYPES``.
    :return: True when the check passes.
    :raises ValueError: if any element's type is outside ``BASIC_TYPES``.
    :raises RuntimeError: (from ``_basic_type_detection``) if the remaining
        basic types are an incompatible mixture.
    """
    type_set = set(type(obj) for obj in val)
    if any(obj not in self.BASIC_TYPES for obj in type_set):
        # convert the set to a list so the error message renders predictably
        raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
    # otherwise: _basic_type_detection will raise error
    self._basic_type_detection(type_set)
    return True
@@ -192,7 +192,7 @@ class ConditionalRandomField(nn.Module): | |||
seq_len, batch_size, n_tags = logits.size() | |||
alpha = logits[0] | |||
if self.include_start_end_trans: | |||
alpha += self.start_scores.view(1, -1) | |||
alpha = alpha + self.start_scores.view(1, -1) | |||
flip_mask = mask.eq(0) | |||
@@ -204,7 +204,7 @@ class ConditionalRandomField(nn.Module): | |||
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) | |||
if self.include_start_end_trans: | |||
alpha += self.end_scores.view(1, -1) | |||
alpha = alpha + self.end_scores.view(1, -1) | |||
return log_sum_exp(alpha, 1) | |||
@@ -233,7 +233,7 @@ class ConditionalRandomField(nn.Module): | |||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||
last_idx = mask.long().sum(0) - 1 | |||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||
score += st_scores + ed_scores | |||
score = score + st_scores + ed_scores | |||
# return [B,] | |||
return score | |||
@@ -216,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase): | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
def test_add_null(self):
    """Adding an empty field and marking it as target must not raise."""
    dataset = DataSet()
    dataset.add_field('test', [])
    dataset.set_target('test')
class TestDataSetIter(unittest.TestCase): | |||
def test__repr__(self): | |||
@@ -101,4 +101,28 @@ class TestCRF(unittest.TestCase): | |||
# # seq equal | |||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||
def test_case3(self):
    """Train a CRF on random data and verify the loss never becomes negative.

    The CRF loss is a negative log-likelihood, so it must stay strictly
    positive at every optimization step — the original version only checked
    the loss after the final step, which could miss a transient negative
    value mid-training.
    """
    import torch
    from torch import nn
    from torch import optim
    from fastNLP.core.utils import seq_lens_to_masks
    from fastNLP.modules.decoder.CRF import ConditionalRandomField

    num_tags, include_start_end_trans = 4, True
    num_samples = 4
    # random sequence lengths in [3, 50); masks hide the padded positions
    lengths = torch.randint(3, 50, size=(num_samples,)).long()
    max_len = lengths.max()
    tags = torch.randint(num_tags, size=(num_samples, max_len))
    masks = seq_lens_to_masks(lengths)
    # emission scores are trained jointly with the CRF transition scores
    feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
    crf = ConditionalRandomField(num_tags, include_start_end_trans)
    trainable = [param for param in crf.parameters() if param.requires_grad]
    optimizer = optim.SGD(trainable + [feats], lr=0.1)
    for step in range(10000):
        loss = crf(feats, tags, masks).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            print(loss)
        # check every step, not just the last one
        assert loss.item() > 0, "CRF loss cannot be less than 0."