@@ -217,11 +217,11 @@ class ModelProcessor(Processor):
                 tmp_batch = []
                 value = value.cpu().numpy()
                 if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
+                    batch_output[key].extend(value.tolist())
+                else:
                     for idx, seq_len in enumerate(seq_lens):
                         tmp_batch.append(value[idx, :seq_len])
                     batch_output[key].extend(tmp_batch)
-                else:
-                    batch_output[key].extend(value.tolist())

             batch_output[self.seq_len_field_name].extend(seq_lens)
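The reordered branches route per-instance outputs (shape [batch] or [batch, 1]) straight through value.tolist(), while per-token outputs are first trimmed to each sequence's true length. A minimal standalone sketch of the trimming case, using made-up toy arrays:

import numpy as np

seq_lens = [2, 3]
tags = np.array([[1, 2, 0],
                 [1, 1, 2]])                     # [batch, max_len] per-token output
trimmed = [tags[i, :l] for i, l in enumerate(seq_lens)]
# trimmed -> [array([1, 2]), array([1, 1, 2])]: padding stripped before collection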
@@ -53,6 +53,54 @@ class SeqLabelEvaluator(Evaluator):
         accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}

+
+class SeqLabelEvaluator2(Evaluator):
+    # The evaluator above appears to be incorrect.
+    def __init__(self, seq_lens_field_name='word_seq_origin_len'):
+        super(SeqLabelEvaluator2, self).__init__()
+        # Indices of word-final tags (e.g. E/S in BMES); left empty here,
+        # presumably to be filled in by the caller.
+        self.end_tagidx_set = set()
+        self.seq_lens_field_name = seq_lens_field_name
+
+    def __call__(self, predict, truth, **kwargs):
+        """
+        :param predict: list of batch, the network outputs from all batches.
+        :param truth: list of dict, the ground truths from all batch_y.
+        :return accuracy:
+        """
+        seq_lens = kwargs[self.seq_lens_field_name]
+        corr_count = 0
+        pred_count = 0
+        truth_count = 0
+        for x, y, seq_len in zip(predict, truth, seq_lens):
+            x = x.cpu().numpy()
+            y = y.cpu().numpy()
+            for idx, s_l in enumerate(seq_len):
+                x_ = x[idx][:s_l]
+                y_ = y[idx][:s_l]
+                flag = True
+                start = 0
+                for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)):
+                    if x_i in self.end_tagidx_set:
+                        truth_count += 1
+                        for j in range(start, idx_i + 1):
+                            if y_[j] != x_[j]:
+                                flag = False
+                                break
+                        if flag:
+                            corr_count += 1
+                        flag = True
+                        start = idx_i + 1
+                    if y_i in self.end_tagidx_set:
+                        pred_count += 1
+        P = corr_count / (float(pred_count) + 1e-6)
+        R = corr_count / (float(truth_count) + 1e-6)
+        F = 2 * P * R / (P + R + 1e-6)
+        return {"P": P, "R": R, "F": F}
+

 class SNLIEvaluator(Evaluator):
     def __init__(self):
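SeqLabelEvaluator2 scores whole words rather than single tags: a word counts as correct only if every tag from the previous boundary up to its word-final tag matches. (Note that x is drawn from predict and y from truth, so the truth_count/pred_count names appear swapped; F is unaffected, since it is symmetric in P and R.) A standalone sketch of the span-matching rule, assuming hypothetical BMES ids 0/1/2/3 = B/M/E/S so the word-final set is {2, 3}:

def count_correct(pred, gold, end_set=frozenset({2, 3})):
    corr, start = 0, 0
    for i, g in enumerate(gold):
        if g in end_set:                      # a gold word ends at position i
            if pred[start:i + 1] == gold[start:i + 1]:
                corr += 1                     # the whole span matched
            start = i + 1
    return corr

print(count_correct(pred=[0, 2, 3, 3], gold=[0, 2, 0, 2]))  # 1: only the first word matches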
@@ -167,8 +167,10 @@ class AdvSeqLabel(SeqLabeling):
         x = self.Linear2(x)
         # x = x.view(batch_size, max_len, -1)
         # [batch_size, max_len, num_classes]
+        # TODO: hard-coding the key name for seq_lens like this is not ideal.
         return {"loss": self._internal_loss(x, truth) if truth is not None else None,
-                "predict": self.decode(x)}
+                "predict": self.decode(x),
+                "word_seq_origin_len": word_seq_origin_len}

     def predict(self, **x):
         out = self.forward(**x)
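Returning word_seq_origin_len alongside the predictions lets downstream consumers (such as SeqLabelEvaluator2 above, whose default seq_lens_field_name is exactly this key) strip padding. A small sketch with a hypothetical stand-in for the returned dict:

import torch

output = {'loss': torch.tensor(1.23),
          'predict': torch.tensor([[0, 2, 3]]),
          'word_seq_origin_len': torch.tensor([2])}
seq_lens = output['word_seq_origin_len']
print(output['predict'][0, :seq_lens[0]])  # tensor([0, 2]): padding removed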
@@ -111,7 +111,7 @@ class POSCWSReader(DataSetLoader):
                 continue
             line = ' '.join(words)
             if cut_long_sent:
-                sents = cut_long_sent(line)
+                sents = cut_long_sentence(line)
             else:
                 sents = [line]
             for sent in sents:
@@ -127,3 +127,50 @@ class POSCWSReader(DataSetLoader):
         return dataset

+
+class ConlluCWSReader(object):
+    # Reads a CoNLL-U file and returns a DataSet with a raw_sentence field
+    # (the words of each sentence joined by spaces).
+    def __init__(self):
+        pass
+
+    def load(self, path, cut_long_sent=False):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            for line in f:
+                if line.startswith('\n'):
+                    datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split('\t'))
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        ds = DataSet()
+        for sample in datalist:
+            res = self.get_one(sample)
+            if res is None:
+                continue
+            line = ' '.join(res)
+            if cut_long_sent:
+                sents = cut_long_sentence(line)
+            else:
+                sents = [line]
+            for raw_sentence in sents:
+                ds.append(Instance(raw_sentence=raw_sentence))
+
+        return ds
+
+    def get_one(self, sample):
+        # Returns the word forms of one sentence, or None for sentences
+        # whose head column is '_'.
+        if len(sample) == 0:
+            return None
+        text = []
+        for w in sample:
+            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
+            if t3 == '_':
+                return None
+            text.append(t1)
+        return text
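ConlluCWSReader splits the file into sentence blocks on blank lines, drops '#' comment lines, and keeps column 2 (the word form) of each row. A self-contained sketch of the same parsing loop on a made-up two-sentence CoNLL-U snippet:

conllu = ('# sent_id = 1\n'
          '1\t小明\t_\tNR\t_\t_\t2\tnsubj\n'
          '2\t来\t_\tVV\t_\t_\t0\troot\n'
          '\n'
          '1\t好\t_\tVA\t_\t_\t0\troot\n')
samples, sample = [], []
for line in conllu.splitlines(keepends=True):
    if line.startswith('\n'):
        samples.append(sample)
        sample = []
    elif line.startswith('#'):
        continue
    else:
        sample.append(line.rstrip('\n').split('\t'))
if sample:
    samples.append(sample)
print([' '.join(w[1] for w in s) for s in samples])  # ['小明 来', '好']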
@@ -117,3 +117,56 @@ class CWSBiLSTMSegApp(BaseModel):
         pred_probs = pred_dict['pred_probs']
         _, pred_tags = pred_probs.max(dim=-1)
         return {'pred_tags': pred_tags}
+
+
+from fastNLP.modules.decoder.CRF import ConditionalRandomField
+
+
+class CWSBiLSTMCRF(BaseModel):
+    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100,
+                 num_bigram_per_char=None, hidden_size=200, bidirectional=True, embed_drop_p=None,
+                 num_layers=1, tag_size=4):
+        super(CWSBiLSTMCRF, self).__init__()
+        self.tag_size = tag_size
+
+        self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim,
+                                              num_bigram_per_char, hidden_size, bidirectional,
+                                              embed_drop_p, num_layers)
+        size_layer = [hidden_size, 200, tag_size]
+        self.decoder_model = MLP(size_layer)
+        self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False)
+
+    def forward(self, chars, tags, seq_lens, bigrams=None):
+        device = self.parameters().__next__().device
+        chars = chars.to(device).long()
+        if bigrams is not None:
+            bigrams = bigrams.to(device).long()
+        seq_lens = seq_lens.to(device).long()
+        masks = seq_lens_to_mask(seq_lens)
+        feats = self.encoder_model(chars, bigrams, seq_lens)
+        feats = self.decoder_model(feats)
+        losses = self.crf(feats, tags, masks)
+
+        pred_dict = {}
+        pred_dict['seq_lens'] = seq_lens
+        pred_dict['loss'] = torch.mean(losses)
+
+        return pred_dict
+
+    def predict(self, chars, seq_lens, bigrams=None):
+        device = self.parameters().__next__().device
+        chars = chars.to(device).long()
+        if bigrams is not None:
+            bigrams = bigrams.to(device).long()
+        seq_lens = seq_lens.to(device).long()
+        masks = seq_lens_to_mask(seq_lens)
+        feats = self.encoder_model(chars, bigrams, seq_lens)
+        feats = self.decoder_model(feats)
+        probs = self.crf.viterbi_decode(feats, masks, get_score=False)
+
+        return {'pred_tags': probs}
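CWSBiLSTMCRF pairs one emission network with two heads: the CRF negative log-likelihood at train time and Viterbi decoding at inference, both over the same [batch, max_len, tag_size] features. A condensed sketch of that pattern with a stand-in emission layer (the ConditionalRandomField calls mirror the ones above; the float mask is an assumption in place of seq_lens_to_mask):

import torch
import torch.nn as nn
from fastNLP.modules.decoder.CRF import ConditionalRandomField

emit = nn.Linear(16, 4)                    # stand-in for the BiLSTM encoder + MLP
crf = ConditionalRandomField(tag_size=4, include_start_end_trans=False)

feats = emit(torch.randn(2, 5, 16))        # [batch=2, max_len=5, tag_size=4]
tags = torch.randint(0, 4, (2, 5))
masks = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1]]).float()

loss = crf(feats, tags, masks).mean()                       # training head, as in forward()
paths = crf.viterbi_decode(feats, masks, get_score=False)   # inference head, as in predict()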
@@ -118,6 +118,23 @@ class CWSTagProcessor(Processor):
     def _tags_from_word_len(self, word_len):
         raise NotImplementedError

+
+class CWSBMESTagProcessor(CWSTagProcessor):
+    def __init__(self, field_name, new_added_field_name=None):
+        super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name)
+
+        self.tag_size = 4
+
+    def _tags_from_word_len(self, word_len):
+        tag_list = []
+        if word_len == 1:
+            tag_list.append(3)
+        else:
+            tag_list.append(0)
+            for _ in range(word_len - 2):
+                tag_list.append(1)
+            tag_list.append(2)
+        return tag_list
+

 class CWSSegAppTagProcessor(CWSTagProcessor):
     def __init__(self, field_name, new_added_field_name=None):
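The BMES ids used throughout this diff are 0 = B(egin), 1 = M(iddle), 2 = E(nd), 3 = S(ingle), as implied by _tags_from_word_len above and BMES2OutputProcessor below. The mapping compresses to one line, shown here as a quick self-check:

def tags_from_word_len(word_len):
    return [3] if word_len == 1 else [0] + [1] * (word_len - 2) + [2]

assert tags_from_word_len(1) == [3]            # S
assert tags_from_word_len(2) == [0, 2]         # B E
assert tags_from_word_len(4) == [0, 1, 1, 2]   # B M M E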
@@ -239,3 +256,29 @@ class SegApp2OutputProcessor(Processor):
                     start_idx = idx + 1
             ins[self.new_added_field_name] = ' '.join(words)

+
+class BMES2OutputProcessor(Processor):
+    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags',
+                 new_added_field_name='output'):
+        super(BMES2OutputProcessor, self).__init__(None, None)
+
+        self.chars_field_name = chars_field_name
+        self.tag_field_name = tag_field_name
+        self.new_added_field_name = new_added_field_name
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            pred_tags = ins[self.tag_field_name]
+            chars = ins[self.chars_field_name]
+            words = []
+            start_idx = 0
+            for idx, tag in enumerate(pred_tags):
+                if tag == 3:
+                    # Restoring the original surface text is not handled yet.
+                    words.extend(chars[start_idx:idx + 1])
+                    start_idx = idx + 1
+                elif tag == 2:
+                    words.append(''.join(chars[start_idx:idx + 1]))
+                    start_idx = idx + 1
+            ins[self.new_added_field_name] = ' '.join(words)
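Decoding reverses the tagging: an E tag (2) closes a multi-character word and an S tag (3) emits a single character on its own. The same loop run standalone on a toy sentence:

chars = ['他', '喜', '欢', '猫']
pred_tags = [3, 0, 2, 3]                  # S B E S
words, start_idx = [], 0
for idx, tag in enumerate(pred_tags):
    if tag == 3:
        words.extend(chars[start_idx:idx + 1])
        start_idx = idx + 1
    elif tag == 2:
        words.append(''.join(chars[start_idx:idx + 1]))
        start_idx = idx + 1
print(' '.join(words))                    # 他 喜欢 猫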
@@ -24,37 +24,52 @@ def refine_ys_on_seq_len(ys, seq_lens):
 def flat_nested_list(nested_list):
     return list(chain(*nested_list))

-def calculate_pre_rec_f1(model, batcher):
+def calculate_pre_rec_f1(model, batcher, type='segapp'):
     true_ys, pred_ys = decode_iterator(model, batcher)

     true_ys = flat_nested_list(true_ys)
     pred_ys = flat_nested_list(pred_ys)

     cor_num = 0
-    yp_wordnum = pred_ys.count(1)
-    yt_wordnum = true_ys.count(1)
     start = 0
-    if true_ys[0] == 1 and pred_ys[0] == 1:
-        cor_num += 1
-        start = 1
-    for i in range(1, len(true_ys)):
-        if true_ys[i] == 1:
-            flag = True
-            if true_ys[start - 1] != pred_ys[start - 1]:
-                flag = False
-            else:
+    if type == 'segapp':
+        yp_wordnum = pred_ys.count(1)
+        yt_wordnum = true_ys.count(1)
+
+        if true_ys[0] == 1 and pred_ys[0] == 1:
+            cor_num += 1
+            start = 1
+        for i in range(1, len(true_ys)):
+            if true_ys[i] == 1:
+                flag = True
+                if true_ys[start - 1] != pred_ys[start - 1]:
+                    flag = False
+                else:
+                    for j in range(start, i + 1):
+                        if true_ys[j] != pred_ys[j]:
+                            flag = False
+                            break
+                if flag:
+                    cor_num += 1
+                start = i + 1
+    elif type == 'bmes':
+        yp_wordnum = pred_ys.count(2) + pred_ys.count(3)
+        yt_wordnum = true_ys.count(2) + true_ys.count(3)
+        for i in range(len(true_ys)):
+            if true_ys[i] == 2 or true_ys[i] == 3:
+                flag = True
                 for j in range(start, i + 1):
                     if true_ys[j] != pred_ys[j]:
                         flag = False
                         break
-            if flag:
-                cor_num += 1
-            start = i + 1
+                if flag:
+                    cor_num += 1
+                start = i + 1
     P = cor_num / (float(yp_wordnum) + 1e-6)
     R = cor_num / (float(yt_wordnum) + 1e-6)
     F = 2 * P * R / (P + R + 1e-6)
-    print(cor_num, yt_wordnum, yp_wordnum)
+    # print(cor_num, yt_wordnum, yp_wordnum)
     return P, R, F
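In the 'bmes' branch, word counts come from the E (2) and S (3) tags on each side, and a word is correct only when its full tag span matches. A hand-worked toy case:

true_ys = [0, 2, 3, 0, 2]   # gold words: chars 0-1, char 2, chars 3-4
pred_ys = [0, 2, 0, 1, 2]   # predicted words: chars 0-1, chars 2-4
# yt_wordnum = 3, yp_wordnum = 2, cor_num = 1 (only the first span matches),
# so P = 1/2, R = 1/3 and F = 2*P*R / (P + R) = 0.4.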
@@ -0,0 +1,89 @@
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+
+
+def cut_long_sentence(sent, max_sample_length=200):
+    sent_no_space = sent.replace(' ', '')
+    cutted_sentence = []
+    if len(sent_no_space) > max_sample_length:
+        parts = sent.strip().split()
+        new_line = ''
+        length = 0
+        for part in parts:
+            length += len(part)
+            new_line += part + ' '
+            if length > max_sample_length:
+                new_line = new_line[:-1]
+                cutted_sentence.append(new_line)
+                length = 0
+                new_line = ''
+        if new_line != '':
+            cutted_sentence.append(new_line[:-1])
+    else:
+        cutted_sentence.append(sent)
+    return cutted_sentence
+
+
+class ConlluPOSReader(object):
+    # The returned DataSet has two fields: words (a list of lists, where the
+    # inner lists hold characters) and tag (a list of str, each a BMES-style tag).
+    def __init__(self):
+        pass
+
+    def load(self, path):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            for line in f:
+                if line.startswith('\n'):
+                    datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split('\t'))
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        ds = DataSet()
+        for sample in datalist:
+            res = self.get_one(sample)
+            if res is None:
+                continue
+            char_seq = []
+            pos_seq = []
+            for word, tag in zip(res[0], res[1]):
+                if len(word) == 1:
+                    char_seq.append(word)
+                    pos_seq.append('S-{}'.format(tag))
+                elif len(word) > 1:
+                    pos_seq.append('B-{}'.format(tag))
+                    for _ in range(len(word) - 2):
+                        pos_seq.append('M-{}'.format(tag))
+                    pos_seq.append('E-{}'.format(tag))
+                    char_seq.extend(list(word))
+                else:
+                    raise ValueError("Zero length of word detected.")
+
+            ds.append(Instance(words=char_seq,
+                               tag=pos_seq))
+
+        return ds
+
+    def get_one(self, sample):
+        if len(sample) == 0:
+            return None
+        text = []
+        pos_tags = []
+        for w in sample:
+            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
+            if t3 == '_':
+                return None
+            text.append(t1)
+            pos_tags.append(t2)
+        return text, pos_tags
+
+
+if __name__ == '__main__':
+    reader = ConlluPOSReader()
+    d = reader.load('/home/hyan/train.conllx')
+    print('reader')
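The character/tag expansion at the heart of ConlluPOSReader turns each (word, tag) pair into per-character BMES-prefixed labels. The same logic as a standalone snippet on toy word/tag pairs:

def to_char_bmes(words, tags):
    chars, char_tags = [], []
    for word, tag in zip(words, tags):
        if len(word) == 1:
            char_tags.append('S-{}'.format(tag))
        else:
            char_tags.append('B-{}'.format(tag))
            char_tags.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
            char_tags.append('E-{}'.format(tag))
        chars.extend(word)
    return chars, char_tags

print(to_char_bmes(['小明', '来', '北京大学'], ['NR', 'VV', 'NT']))
# (['小', '明', '来', '北', '京', '大', '学'],
#  ['B-NR', 'E-NR', 'S-VV', 'B-NT', 'M-NT', 'M-NT', 'E-NT'])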
@@ -1,16 +1,18 @@
 [train]
-epochs = 300
+epochs = 6
 batch_size = 32
 pickle_path = "./save/"
-validate = false
+validate = true
 save_best_dev = true
 model_saved_path = "./save/"
+valid_step = 250
+eval_sort_key = 'accuracy'

 [model]
-rnn_hidden_units = 100
-word_emb_dim = 100
+rnn_hidden_units = 300
+word_emb_dim = 300
+dropout = 0.5
 use_crf = true
-use_cuda = true
 print_every_step = 10

 [test]
@@ -34,4 +36,4 @@ pickle_path = "./save/"
 use_crf = true
 use_cuda = true
 rnn_hidden_units = 100
-word_emb_dim = 100
\ No newline at end of file
+word_emb_dim = 100
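The new [train] keys turn on periodic validation every 250 steps and rank checkpoints by accuracy. A minimal sketch of reading such a file with the standard library, assuming it is saved as pos_tag.cfg (hypothetical filename; fastNLP ships its own config loader):

import configparser

cfg = configparser.ConfigParser()
cfg.read('pos_tag.cfg')
epochs = cfg.getint('train', 'epochs')            # 6
validate = cfg.getboolean('train', 'validate')    # True
hidden = cfg.getint('model', 'rnn_hidden_units')  # 300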
@@ -78,7 +78,7 @@ class PosOutputStrProcessor(Processor):
             word_pos_list = []
             for word, pos in zip(word_list, pos_list):
                 word_pos_list.append(word + self.sep + pos)
+            # TODO: this should be customizable.
             ins['word_pos_output'] = ' '.join(word_pos_list)

         return dataset
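PosOutputStrProcessor renders word/tag pairs into a single string with a configurable separator. With a hypothetical sep of '_', the join works out to:

sep = '_'
word_list, pos_list = ['小明', '来'], ['NR', 'VV']
print(' '.join(w + sep + p for w, p in zip(word_list, pos_list)))  # 小明_NR 来_VV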