@@ -217,11 +217,11 @@ class ModelProcessor(Processor): | |||
tmp_batch = [] | |||
value = value.cpu().numpy() | |||
if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1): | |||
batch_output[key].extend(value.tolist()) | |||
else: | |||
for idx, seq_len in enumerate(seq_lens): | |||
tmp_batch.append(value[idx, :seq_len]) | |||
batch_output[key].extend(tmp_batch) | |||
else: | |||
batch_output[key].extend(value.tolist()) | |||
batch_output[self.seq_len_field_name].extend(seq_lens) | |||
@@ -53,6 +53,54 @@ class SeqLabelEvaluator(Evaluator): | |||
accuracy = total_correct / total_count | |||
return {"accuracy": float(accuracy)} | |||
class SeqLabelEvaluator2(Evaluator):
    """Word-level precision / recall / F evaluator for tagged sequences.

    A word is considered to end at every position whose tag index is contained
    in ``end_tagidx_set``. A predicted word counts as correct when every tag
    from the end of the previous predicted word up to and including its own
    end position equals the corresponding ground-truth tag.
    """

    # NOTE: the evaluator defined before this class is believed to be incorrect.

    def __init__(self, seq_lens_field_name='word_seq_origin_len'):
        """
        :param seq_lens_field_name: key under which ``__call__`` looks up the
            per-batch sequence lengths in its keyword arguments.
        """
        super(SeqLabelEvaluator2, self).__init__()
        # Tag indices that close a word. Starts empty, so nothing is counted
        # until the caller fills it in.
        self.end_tagidx_set = set()
        self.seq_lens_field_name = seq_lens_field_name

    def __call__(self, predict, truth, **_):
        """Compute precision, recall and F over all batches.

        :param predict: iterable of per-batch predictions; each item must
            support ``.cpu().numpy()`` and be indexable by sample.
        :param truth: iterable of per-batch ground truths, same layout as
            ``predict``.
        :param _: extra keyword arguments; must contain
            ``self.seq_lens_field_name`` mapping to an iterable of per-batch
            sequence lengths.
        :return: dict with keys ``"P"``, ``"R"`` and ``"F"``.
        """
        seq_lens = _[self.seq_lens_field_name]
        end_tags = self.end_tagidx_set

        corr_count = 0
        pred_count = 0
        truth_count = 0
        for pred_batch, true_batch, batch_lens in zip(predict, truth, seq_lens):
            pred_batch = pred_batch.cpu().numpy()
            true_batch = true_batch.cpu().numpy()
            for row, length in enumerate(batch_lens):
                # Drop padding beyond the real sequence length.
                pred_tags = pred_batch[row][:length]
                true_tags = true_batch[row][:length]
                start = 0
                for pos, (pred_tag, true_tag) in enumerate(zip(pred_tags, true_tags)):
                    # Ground-truth word ends feed the recall denominator.
                    if true_tag in end_tags:
                        truth_count += 1
                    # Predicted word ends feed the precision denominator and
                    # delimit the span that is checked for correctness.
                    if pred_tag in end_tags:
                        pred_count += 1
                        span_matches = all(
                            p == t
                            for p, t in zip(pred_tags[start:pos + 1], true_tags[start:pos + 1])
                        )
                        if span_matches:
                            corr_count += 1
                        start = pos + 1

        # The small epsilon keeps the divisions defined when a count is zero.
        P = corr_count / (float(pred_count) + 1e-6)
        R = corr_count / (float(truth_count) + 1e-6)
        F = 2 * P * R / (P + R + 1e-6)
        return {"P": P, 'R': R, 'F': F}
class SNLIEvaluator(Evaluator): | |||
def __init__(self): | |||
@@ -167,8 +167,10 @@ class AdvSeqLabel(SeqLabeling): | |||
x = self.Linear2(x) | |||
# x = x.view(batch_size, max_len, -1) | |||
# [batch_size, max_len, num_classes] | |||
# TODO seq_lens的key这样做不合理 | |||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||
"predict": self.decode(x)} | |||
"predict": self.decode(x), | |||
'word_seq_origin_len': word_seq_origin_len} | |||
def predict(self, **x): | |||
out = self.forward(**x) | |||
@@ -111,7 +111,7 @@ class POSCWSReader(DataSetLoader): | |||
continue | |||
line = ' '.join(words) | |||
if cut_long_sent: | |||
sents = cut_long_sent(line) | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for sent in sents: | |||
@@ -127,3 +127,50 @@ class POSCWSReader(DataSetLoader): | |||
return dataset | |||
class ConlluCWSReader(object):
    """Read a tab-separated, blank-line-delimited treebank file for word segmentation.

    The returned DataSet holds one ``Instance`` per (optionally cut) sentence,
    each with a single ``raw_sentence`` field: the word forms of the sentence
    joined by single spaces.
    """

    def __init__(self):
        pass

    def load(self, path, cut_long_sent=False):
        """Load ``path`` and build the DataSet.

        :param path: path of the UTF-8 encoded input file.
        :param cut_long_sent: when truthy, every sentence is passed through
            ``cut_long_sentence`` and each resulting piece becomes its own
            instance.
        :return: DataSet of instances with a ``raw_sentence`` field.
        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if not line.strip():
                    # A blank (or whitespace-only) line closes the current
                    # sentence; whitespace-only lines used to be parsed as a
                    # token row and crashed in get_one.
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    # Comment / metadata line.
                    continue
                else:
                    sample.append(line.split('\t'))
            if sample:
                # The file did not end with a blank line.
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            res = self.get_one(sample)
            if res is None:
                continue
            line = ' '.join(res)
            sents = cut_long_sentence(line) if cut_long_sent else [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))

        return ds

    def get_one(self, sample):
        """Return the word forms of one sentence, or None if it must be skipped.

        :param sample: list of token rows, each a list of column strings.
        :return: list of column-1 values, or ``None`` when ``sample`` is empty
            or any row has ``'_'`` in column 6.
        """
        if len(sample) == 0:
            return None
        text = []
        for w in sample:
            # Presumably column 1 is the word form and column 6 the head index
            # of a CoNLL-style row -- TODO confirm against the data format.
            form, head = w[1], w[6]
            if head == '_':
                return None
            text.append(form)
        return text
@@ -117,3 +117,56 @@ class CWSBiLSTMSegApp(BaseModel): | |||
pred_probs = pred_dict['pred_probs'] | |||
_, pred_tags = pred_probs.max(dim=-1) | |||
return {'pred_tags': pred_tags} | |||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||
class CWSBiLSTMCRF(BaseModel):
    """Segmentation model: ``CWSBiLSTMEncoder`` -> ``MLP`` -> ``ConditionalRandomField``."""

    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4):
        """Build the encoder, the MLP projecting to ``tag_size`` scores and the CRF layer."""
        super(CWSBiLSTMCRF, self).__init__()

        self.tag_size = tag_size

        self.encoder_model = CWSBiLSTMEncoder(
            vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
            hidden_size, bidirectional, embed_drop_p, num_layers,
        )
        # MLP layer sizes: encoder features -> 200 -> tag_size.
        self.decoder_model = MLP([hidden_size, 200, tag_size])
        self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False)

    def _encode(self, chars, seq_lens, bigrams):
        """Move the inputs to the model's device and run encoder + MLP.

        :return: tuple ``(feats, masks, seq_lens)`` where ``seq_lens`` is the
            device-resident ``long`` version of the argument.
        """
        # The device of the first parameter is taken as the model's device.
        device = next(self.parameters()).device
        chars = chars.to(device).long()
        if bigrams is not None:
            bigrams = bigrams.to(device).long()
        seq_lens = seq_lens.to(device).long()
        masks = seq_lens_to_mask(seq_lens)
        feats = self.decoder_model(self.encoder_model(chars, bigrams, seq_lens))
        return feats, masks, seq_lens

    def forward(self, chars, tags, seq_lens, bigrams=None):
        """Compute the training loss.

        :return: dict with ``'seq_lens'`` (the device-resident lengths) and
            ``'loss'`` (mean of the values returned by the CRF layer).
        """
        feats, masks, seq_lens = self._encode(chars, seq_lens, bigrams)
        losses = self.crf(feats, tags, masks)
        return {
            'seq_lens': seq_lens,
            'loss': torch.mean(losses),
        }

    def predict(self, chars, seq_lens, bigrams=None):
        """Decode tags with the CRF's Viterbi routine.

        :return: dict with ``'pred_tags'`` holding the result of
            ``viterbi_decode(..., get_score=False)``.
        """
        feats, masks, _ = self._encode(chars, seq_lens, bigrams)
        decoded = self.crf.viterbi_decode(feats, masks, get_score=False)
        return {'pred_tags': decoded}
@@ -118,6 +118,23 @@ class CWSTagProcessor(Processor): | |||
def _tags_from_word_len(self, word_len): | |||
raise NotImplementedError | |||
class CWSBMESTagProcessor(CWSTagProcessor):
    """Tag processor emitting four tag indices (presumably B/M/E/S, per the class name)."""

    def __init__(self, field_name, new_added_field_name=None):
        super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name)

        self.tag_size = 4

    def _tags_from_word_len(self, word_len):
        """Return the tag indices for a word of ``word_len`` units.

        A one-unit word maps to ``[3]``; any other length maps to ``0``,
        then ``word_len - 2`` copies of ``1``, then ``2``.
        """
        if word_len == 1:
            return [3]
        # A negative repeat count yields an empty list, like an empty range().
        return [0] + [1] * (word_len - 2) + [2]
class CWSSegAppTagProcessor(CWSTagProcessor): | |||
def __init__(self, field_name, new_added_field_name=None): | |||
@@ -239,3 +256,29 @@ class SegApp2OutputProcessor(Processor): | |||
start_idx = idx + 1 | |||
ins[self.new_added_field_name] = ' '.join(words) | |||
class BMES2OutputProcessor(Processor):
    """Turn predicted tag indices back into a space-separated word string.

    Tag ``2`` closes a multi-character word and tag ``3`` marks a word that
    closes at the current character; every other tag value just extends the
    pending span.
    """

    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
        """
        :param chars_field_name: field holding the character sequence.
        :param tag_field_name: field holding the predicted tag indices.
        :param new_added_field_name: field that receives the output string.
        """
        super(BMES2OutputProcessor, self).__init__(None, None)

        self.chars_field_name = chars_field_name
        self.tag_field_name = tag_field_name

        self.new_added_field_name = new_added_field_name

    def process(self, dataset):
        """Write the segmented string of every instance into ``new_added_field_name``.

        :param dataset: the DataSet to process (modified in place).
        :return: the same ``dataset``, so the processor can be chained.
        """
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            pred_tags = ins[self.tag_field_name]
            chars = ins[self.chars_field_name]
            words = []
            start_idx = 0
            for idx, tag in enumerate(pred_tags):
                if tag == 3:
                    # Substituting the original text back is not handled yet.
                    # Any pending characters are emitted one per word.
                    words.extend(chars[start_idx:idx + 1])
                    start_idx = idx + 1
                elif tag == 2:
                    words.append(''.join(chars[start_idx:idx + 1]))
                    start_idx = idx + 1
            # A prediction that does not finish with an end tag used to lose
            # its trailing characters; keep them as one final word instead.
            tail = chars[start_idx:len(pred_tags)]
            if len(tail) > 0:
                words.append(''.join(tail))
            ins[self.new_added_field_name] = ' '.join(words)
        return dataset
@@ -24,37 +24,52 @@ def refine_ys_on_seq_len(ys, seq_lens): | |||
def flat_nested_list(nested_list):
    """Flatten one level of nesting and return the items as a new list.

    :param nested_list: iterable of iterables.
    :return: list with the items of every inner iterable, in order.
    """
    # from_iterable consumes the outer iterable lazily instead of unpacking it
    # into call arguments.
    return list(chain.from_iterable(nested_list))
def calculate_pre_rec_f1(model, batcher, type='segapp'):
    """Compute word-level precision, recall and F1 for a segmentation model.

    :param model: model handed to ``decode_iterator``.
    :param batcher: batch iterator handed to ``decode_iterator``.
    :param type: tag scheme. With ``'segapp'`` a tag of ``1`` closes a word;
        with ``'bmes'`` a tag of ``2`` or ``3`` closes a word.
    :return: tuple ``(P, R, F)``.
    :raises ValueError: if ``type`` is not one of the supported schemes.
    """
    if type not in ('segapp', 'bmes'):
        # Fail before decoding instead of hitting an unbound local afterwards.
        raise ValueError("Unsupported tag scheme {!r}; expected 'segapp' or 'bmes'.".format(type))

    true_ys, pred_ys = decode_iterator(model, batcher)

    true_ys = flat_nested_list(true_ys)
    pred_ys = flat_nested_list(pred_ys)

    cor_num = 0
    start = 0
    if type == 'segapp':
        yp_wordnum = pred_ys.count(1)
        yt_wordnum = true_ys.count(1)

        # A one-element first word is handled up front; the truthiness check
        # keeps empty input from raising IndexError.
        if true_ys and true_ys[0] == 1 and pred_ys[0] == 1:
            cor_num += 1
            start = 1

        for i in range(1, len(true_ys)):
            if true_ys[i] == 1:
                # The boundary before the word must agree as well. At
                # start == 0 there is no previous boundary; indexing with -1
                # would wrap around to the last element.
                flag = start == 0 or true_ys[start - 1] == pred_ys[start - 1]
                if flag:
                    flag = true_ys[start:i + 1] == pred_ys[start:i + 1]
                if flag:
                    cor_num += 1
                start = i + 1
    else:  # type == 'bmes'
        yp_wordnum = pred_ys.count(2) + pred_ys.count(3)
        yt_wordnum = true_ys.count(2) + true_ys.count(3)
        for i in range(len(true_ys)):
            if true_ys[i] == 2 or true_ys[i] == 3:
                # A word is correct when every tag of its span matches.
                if true_ys[start:i + 1] == pred_ys[start:i + 1]:
                    cor_num += 1
                start = i + 1

    # The small epsilon keeps the divisions defined when a count is zero.
    P = cor_num / (float(yp_wordnum) + 1e-6)
    R = cor_num / (float(yt_wordnum) + 1e-6)
    F = 2 * P * R / (P + R + 1e-6)
    return P, R, F
@@ -0,0 +1,89 @@ | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
def cut_long_sentence(sent, max_sample_length=200):
    """Split a space-separated sentence into pieces of bounded length.

    Length is counted in characters excluding the separating spaces. Words are
    never broken, so a piece exceeds ``max_sample_length`` only when a single
    word is longer than the limit on its own.

    :param sent: sentence whose words are separated by whitespace.
    :param max_sample_length: maximum number of non-space characters per piece.
    :return: list of pieces; ``[sent]`` unchanged when it already fits.
    """
    if len(sent.replace(' ', '')) <= max_sample_length:
        return [sent]

    pieces = []
    current = []
    current_len = 0
    for part in sent.split():
        # Flush BEFORE adding a word that would overflow the piece; flushing
        # after the overflow let every piece exceed the limit.
        if current and current_len + len(part) > max_sample_length:
            pieces.append(' '.join(current))
            current = []
            current_len = 0
        current.append(part)
        current_len += len(part)
    if current:
        pieces.append(' '.join(current))
    return pieces
class ConlluPOSReader(object):
    """Read a tab-separated, blank-line-delimited treebank file for character-level POS tagging.

    The returned DataSet has two fields per instance: ``words`` (list of
    single characters) and ``tag`` (list of str, one per character, the word's
    tag prefixed with ``B-``/``M-``/``E-``/``S-``).
    """

    def __init__(self):
        pass

    def load(self, path):
        """Load ``path`` and build the DataSet.

        :param path: path of the UTF-8 encoded input file.
        :return: DataSet of instances with ``words`` and ``tag`` fields.
        :raises ValueError: if a zero-length word is encountered.
        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if not line.strip():
                    # A blank (or whitespace-only) line closes the current
                    # sentence; whitespace-only lines used to be parsed as a
                    # token row and crashed in get_one.
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    # Comment / metadata line.
                    continue
                else:
                    sample.append(line.split('\t'))
            if sample:
                # The file did not end with a blank line.
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            res = self.get_one(sample)
            if res is None:
                continue
            char_seq, pos_seq = self._to_char_level(res[0], res[1])
            ds.append(Instance(words=char_seq,
                               tag=pos_seq))

        return ds

    @staticmethod
    def _to_char_level(words, tags):
        """Expand word-level tags into per-character position-prefixed tags."""
        char_seq = []
        pos_seq = []
        for word, tag in zip(words, tags):
            if len(word) == 0:
                raise ValueError("Zero length of word detected.")
            if len(word) == 1:
                pos_seq.append('S-{}'.format(tag))
            else:
                pos_seq.append('B-{}'.format(tag))
                pos_seq.extend(['M-{}'.format(tag)] * (len(word) - 2))
                pos_seq.append('E-{}'.format(tag))
            char_seq.extend(word)
        return char_seq, pos_seq

    def get_one(self, sample):
        """Return the word forms and tags of one sentence, or None to skip it.

        :param sample: list of token rows, each a list of column strings.
        :return: tuple ``(text, pos_tags)`` built from columns 1 and 3, or
            ``None`` when ``sample`` is empty or any row has ``'_'`` in
            column 6.
        """
        if len(sample) == 0:
            return None
        text = []
        pos_tags = []
        for w in sample:
            # Presumably columns 1, 3 and 6 are the word form, POS tag and
            # head index of a CoNLL-style row -- TODO confirm against the data.
            form, pos, head = w[1], w[3], w[6]
            if head == '_':
                return None
            text.append(form)
            pos_tags.append(pos)
        return text, pos_tags
if __name__ == '__main__':
    # Manual smoke test: load a treebank file and report that loading finished.
    import sys

    # An optional first CLI argument overrides the developer-local default path.
    conll_path = sys.argv[1] if len(sys.argv) > 1 else '/home/hyan/train.conllx'

    reader = ConlluPOSReader()
    d = reader.load(conll_path)
    print('reader')
@@ -1,16 +1,18 @@ | |||
[train] | |||
epochs = 300 | |||
epochs = 6 | |||
batch_size = 32 | |||
pickle_path = "./save/" | |||
validate = false | |||
validate = true | |||
save_best_dev = true | |||
model_saved_path = "./save/" | |||
valid_step = 250 | |||
eval_sort_key = 'accuracy' | |||
[model] | |||
rnn_hidden_units = 100 | |||
word_emb_dim = 100 | |||
rnn_hidden_units = 300 | |||
word_emb_dim = 300 | |||
dropout = 0.5 | |||
use_crf = true | |||
use_cuda = true | |||
print_every_step = 10 | |||
[test] | |||
@@ -34,4 +36,4 @@ pickle_path = "./save/" | |||
use_crf = true | |||
use_cuda = true | |||
rnn_hidden_units = 100 | |||
word_emb_dim = 100 | |||
word_emb_dim = 100 |
@@ -78,7 +78,7 @@ class PosOutputStrProcessor(Processor): | |||
word_pos_list = [] | |||
for word, pos in zip(word_list, pos_list): | |||
word_pos_list.append(word + self.sep + pos) | |||
#TODO 应该可以定制 | |||
ins['word_pos_output'] = ' '.join(word_pos_list) | |||
return dataset | |||