
Upload POS and CWS development work

tags/v0.2.0
yh yunfan authored 5 years ago · commit 77786509df
11 changed files with 365 additions and 50 deletions
1. fastNLP/api/api.py (+38, -22)
2. fastNLP/api/processor.py (+2, -2)
3. fastNLP/core/metrics.py (+48, -0)
4. fastNLP/models/sequence_modeling.py (+3, -1)
5. reproduction/chinese_word_segment/cws_io/cws_reader.py (+48, -1)
6. reproduction/chinese_word_segment/models/cws_model.py (+53, -0)
7. reproduction/chinese_word_segment/process/cws_processor.py (+43, -0)
8. reproduction/chinese_word_segment/utils.py (+32, -17)
9. reproduction/pos_tag_model/pos_io/pos_reader.py (+89, -0)
10. reproduction/pos_tag_model/pos_tag.cfg (+8, -6)
11. reproduction/pos_tag_model/process/pos_processor.py (+1, -1)

fastNLP/api/api.py (+38, -22)
File diff suppressed because it is too large.


fastNLP/api/processor.py (+2, -2)

@@ -217,11 +217,11 @@ class ModelProcessor(Processor):
                 tmp_batch = []
                 value = value.cpu().numpy()
                 if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1):
+                    batch_output[key].extend(value.tolist())
+                else:
                     for idx, seq_len in enumerate(seq_lens):
                         tmp_batch.append(value[idx, :seq_len])
                     batch_output[key].extend(tmp_batch)
-                else:
-                    batch_output[key].extend(value.tolist())
 
             batch_output[self.seq_len_field_name].extend(seq_lens)
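The two changed lines appear to swap the branches of the shape check: one-dimensional outputs (or [batch, 1] columns) are now extended into the output list directly, while genuine [batch, max_len] sequence outputs are truncated to each instance's true length. A minimal numpy restatement of the new rule (function and variable names are illustrative only, not part of the processor):

import numpy as np

def collect(value, seq_lens):
    # scalar-per-instance outputs: shape [B] or [B, 1]
    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
        return value.tolist()
    # sequence outputs: shape [B, max_len], cut each row to its real length
    return [value[idx, :seq_len] for idx, seq_len in enumerate(seq_lens)]

print(collect(np.array([0.3, 0.7]), [3, 2]))        # [0.3, 0.7]
print(collect(np.arange(8).reshape(2, 4), [3, 2]))  # [array([0, 1, 2]), array([4, 5])]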



fastNLP/core/metrics.py (+48, -0)

@@ -53,6 +53,54 @@ class SeqLabelEvaluator(Evaluator):
         accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}
 
+class SeqLabelEvaluator2(Evaluator):
+    # the evaluator above is probably wrong
+    def __init__(self, seq_lens_field_name='word_seq_origin_len'):
+        super(SeqLabelEvaluator2, self).__init__()
+        self.end_tagidx_set = set()
+        self.seq_lens_field_name = seq_lens_field_name
+
+    def __call__(self, predict, truth, **_):
+        """
+
+        :param predict: list of batch, the network outputs from all batches.
+        :param truth: list of dict, the ground truths from all batch_y.
+        :return accuracy:
+        """
+        seq_lens = _[self.seq_lens_field_name]
+        corr_count = 0
+        pred_count = 0
+        truth_count = 0
+        for x, y, seq_len in zip(predict, truth, seq_lens):
+            x = x.cpu().numpy()
+            y = y.cpu().numpy()
+            for idx, s_l in enumerate(seq_len):
+                x_ = x[idx]
+                y_ = y[idx]
+                x_ = x_[:s_l]
+                y_ = y_[:s_l]
+                flag = True
+                start = 0
+                for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)):
+                    if x_i in self.end_tagidx_set:
+                        truth_count += 1
+                        for j in range(start, idx_i + 1):
+                            if y_[j]!=x_[j]:
+                                flag = False
+                                break
+                        if flag:
+                            corr_count += 1
+                        flag = True
+                        start = idx_i + 1
+                    if y_i in self.end_tagidx_set:
+                        pred_count += 1
+        P = corr_count / (float(pred_count) + 1e-6)
+        R = corr_count / (float(truth_count) + 1e-6)
+        F = 2 * P * R / (P + R + 1e-6)
+
+        return {"P": P, 'R':R, 'F': F}
+
+
 class SNLIEvaluator(Evaluator):
     def __init__(self):

fastNLP/models/sequence_modeling.py (+3, -1)

@@ -167,8 +167,10 @@ class AdvSeqLabel(SeqLabeling):
         x = self.Linear2(x)
         # x = x.view(batch_size, max_len, -1)
         # [batch_size, max_len, num_classes]
+        # TODO: passing the seq_lens key through like this is not a clean design
         return {"loss": self._internal_loss(x, truth) if truth is not None else None,
-                "predict": self.decode(x)}
+                "predict": self.decode(x),
+                'word_seq_origin_len': word_seq_origin_len}
 
     def predict(self, **x):
         out = self.forward(**x)


reproduction/chinese_word_segment/cws_io/cws_reader.py (+48, -1)

@@ -111,7 +111,7 @@ class POSCWSReader(DataSetLoader):
                 continue
             line = ' '.join(words)
             if cut_long_sent:
-                sents = cut_long_sent(line)
+                sents = cut_long_sentence(line)
             else:
                 sents = [line]
             for sent in sents:
@@ -127,3 +127,50 @@ class POSCWSReader(DataSetLoader):
         return dataset
 
 
+class ConlluCWSReader(object):
+    # The returned DataSet has two fields: words (list of list, the inner list holds characters) and tag (list of str, each str a BMES tag).
+    def __init__(self):
+        pass
+
+    def load(self, path, cut_long_sent=False):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            for line in f:
+                if line.startswith('\n'):
+                    datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split('\t'))
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        ds = DataSet()
+        for sample in datalist:
+            # print(sample)
+            res = self.get_one(sample)
+            if res is None:
+                continue
+            line = ' '.join(res)
+            if cut_long_sent:
+                sents = cut_long_sentence(line)
+            else:
+                sents = [line]
+            for raw_sentence in sents:
+                ds.append(Instance(raw_sentence=raw_sentence))
+
+        return ds
+
+    def get_one(self, sample):
+        if len(sample)==0:
+            return None
+        text = []
+        for w in sample:
+            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
+            if t3 == '_':
+                return None
+            text.append(t1)
+        return text
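Hypothetical usage of the new reader (the path is invented): each blank-line-separated CoNLL-X block becomes one raw_sentence instance, and cut_long_sent additionally splits overlong sentences via cut_long_sentence.

reader = ConlluCWSReader()
ds = reader.load('data/dev.conllx', cut_long_sent=True)  # path is an assumption
for ins in ds:
    print(ins['raw_sentence'])
    break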


reproduction/chinese_word_segment/models/cws_model.py (+53, -0)

@@ -117,3 +117,56 @@ class CWSBiLSTMSegApp(BaseModel):
         pred_probs = pred_dict['pred_probs']
         _, pred_tags = pred_probs.max(dim=-1)
         return {'pred_tags': pred_tags}
+
+
+from fastNLP.modules.decoder.CRF import ConditionalRandomField
+
+class CWSBiLSTMCRF(BaseModel):
+    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
+                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4):
+        super(CWSBiLSTMCRF, self).__init__()
+
+        self.tag_size = tag_size
+
+        self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
+                                              hidden_size, bidirectional, embed_drop_p, num_layers)
+
+        size_layer = [hidden_size, 200, tag_size]
+        self.decoder_model = MLP(size_layer)
+        self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False)
+
+
+    def forward(self, chars, tags, seq_lens, bigrams=None):
+        device = self.parameters().__next__().device
+        chars = chars.to(device).long()
+        if not bigrams is None:
+            bigrams = bigrams.to(device).long()
+        else:
+            bigrams = None
+        seq_lens = seq_lens.to(device).long()
+        masks = seq_lens_to_mask(seq_lens)
+        feats = self.encoder_model(chars, bigrams, seq_lens)
+        feats = self.decoder_model(feats)
+        losses = self.crf(feats, tags, masks)
+
+        pred_dict = {}
+        pred_dict['seq_lens'] = seq_lens
+        pred_dict['loss'] = torch.mean(losses)
+
+        return pred_dict
+
+    def predict(self, chars, seq_lens, bigrams=None):
+        device = self.parameters().__next__().device
+        chars = chars.to(device).long()
+        if not bigrams is None:
+            bigrams = bigrams.to(device).long()
+        else:
+            bigrams = None
+        seq_lens = seq_lens.to(device).long()
+        masks = seq_lens_to_mask(seq_lens)
+        feats = self.encoder_model(chars, bigrams, seq_lens)
+        feats = self.decoder_model(feats)
+        probs = self.crf.viterbi_decode(feats, masks, get_score=False)
+
+        return {'pred_tags': probs}
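A smoke-test sketch for the new model (all hyper-parameters invented; it assumes the surrounding module already provides CWSBiLSTMEncoder, MLP, seq_lens_to_mask and torch, as the class itself does):

model = CWSBiLSTMCRF(vocab_num=100, embed_dim=64, hidden_size=128, tag_size=4)
chars = torch.randint(0, 100, (2, 7))    # [batch, max_len] character ids
tags = torch.randint(0, 4, (2, 7))       # BMES ids in [0, tag_size)
seq_lens = torch.tensor([7, 5])

loss = model(chars, tags, seq_lens)['loss']          # mean CRF loss over the batch
paths = model.predict(chars, seq_lens)['pred_tags']  # Viterbi-decoded tag sequences

Note how the CRF replaces CWSBiLSTMSegApp's per-position argmax: training optimizes a sequence-level likelihood, and decoding runs Viterbi instead of pred_probs.max(dim=-1).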


reproduction/chinese_word_segment/process/cws_processor.py (+43, -0)

@@ -118,6 +118,23 @@ class CWSTagProcessor(Processor):
     def _tags_from_word_len(self, word_len):
         raise NotImplementedError
 
+class CWSBMESTagProcessor(CWSTagProcessor):
+    def __init__(self, field_name, new_added_field_name=None):
+        super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name)
+
+        self.tag_size = 4
+
+    def _tags_from_word_len(self, word_len):
+        tag_list = []
+        if word_len == 1:
+            tag_list.append(3)
+        else:
+            tag_list.append(0)
+            for _ in range(word_len-2):
+                tag_list.append(1)
+            tag_list.append(2)
+
+        return tag_list
+
 class CWSSegAppTagProcessor(CWSTagProcessor):
     def __init__(self, field_name, new_added_field_name=None):
@@ -239,3 +256,29 @@ class SegApp2OutputProcessor(Processor):
                     start_idx = idx + 1
             ins[self.new_added_field_name] = ' '.join(words)
 
+
+class BMES2OutputProcessor(Processor):
+    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
+        super(BMES2OutputProcessor, self).__init__(None, None)
+
+        self.chars_field_name = chars_field_name
+        self.tag_field_name = tag_field_name
+
+        self.new_added_field_name = new_added_field_name
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            pred_tags = ins[self.tag_field_name]
+            chars = ins[self.chars_field_name]
+            words = []
+            start_idx = 0
+            for idx, tag in enumerate(pred_tags):
+                if tag==3:
+                    # restoring the original text is not handled yet
+                    words.extend(chars[start_idx:idx+1])
+                    start_idx = idx + 1
+                elif tag==2:
+                    words.append(''.join(chars[start_idx:idx+1]))
+                    start_idx = idx + 1
+            ins[self.new_added_field_name] = ' '.join(words)
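Read together, the two new classes pin down the integer coding 0=B, 1=M, 2=E, 3=S: the tag processor emits [3] for a single-character word and 0, 1, ..., 1, 2 otherwise, and the output processor cuts a word after every 2 or 3. A round-trip sketch (the sentence is just an example):

def tags_for(word_len):     # CWSBMESTagProcessor._tags_from_word_len, restated
    return [3] if word_len == 1 else [0] + [1] * (word_len - 2) + [2]

print(tags_for(1), tags_for(2), tags_for(4))   # [3] [0, 2] [0, 1, 1, 2]

chars = list('我爱自然语言')
tags = [3, 3, 0, 2, 0, 2]   # 我 | 爱 | 自然 | 语言
words, start = [], 0
for i, t in enumerate(tags):
    if t in (2, 3):         # a word ends at E or S
        words.append(''.join(chars[start:i + 1]))
        start = i + 1
print(' '.join(words))      # 我 爱 自然 语言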

reproduction/chinese_word_segment/utils.py (+32, -17)

@@ -24,37 +24,52 @@ def refine_ys_on_seq_len(ys, seq_lens):
 def flat_nested_list(nested_list):
     return list(chain(*nested_list))
 
-def calculate_pre_rec_f1(model, batcher):
+def calculate_pre_rec_f1(model, batcher, type='segapp'):
     true_ys, pred_ys = decode_iterator(model, batcher)
 
     true_ys = flat_nested_list(true_ys)
     pred_ys = flat_nested_list(pred_ys)
 
     cor_num = 0
-    yp_wordnum = pred_ys.count(1)
-    yt_wordnum = true_ys.count(1)
     start = 0
-    if true_ys[0]==1 and pred_ys[0]==1:
-        cor_num += 1
-        start = 1
-
-    for i in range(1, len(true_ys)):
-        if true_ys[i] == 1:
-            flag = True
-            if true_ys[start-1] != pred_ys[start-1]:
-                flag = False
-            else:
+    if type=='segapp':
+        yp_wordnum = pred_ys.count(1)
+        yt_wordnum = true_ys.count(1)
+
+        if true_ys[0]==1 and pred_ys[0]==1:
+            cor_num += 1
+            start = 1
+
+        for i in range(1, len(true_ys)):
+            if true_ys[i] == 1:
+                flag = True
+                if true_ys[start-1] != pred_ys[start-1]:
+                    flag = False
+                else:
+                    for j in range(start, i + 1):
+                        if true_ys[j] != pred_ys[j]:
+                            flag = False
+                            break
+                if flag:
+                    cor_num += 1
+                start = i + 1
+    elif type=='bmes':
+        yp_wordnum = pred_ys.count(2) + pred_ys.count(3)
+        yt_wordnum = true_ys.count(2) + true_ys.count(3)
+        for i in range(len(true_ys)):
+            if true_ys[i] == 2 or true_ys[i] == 3:
+                flag = True
                 for j in range(start, i + 1):
                     if true_ys[j] != pred_ys[j]:
                         flag = False
                         break
-            if flag:
-                cor_num += 1
-            start = i + 1
+                if flag:
+                    cor_num += 1
+                start = i + 1
     P = cor_num / (float(yp_wordnum) + 1e-6)
     R = cor_num / (float(yt_wordnum) + 1e-6)
     F = 2 * P * R / (P + R + 1e-6)
-    print(cor_num, yt_wordnum, yp_wordnum)
+    # print(cor_num, yt_wordnum, yp_wordnum)
    return P, R, F
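A worked check of the new 'bmes' branch, with the same 0=B/1=M/2=E/3=S coding as the processors above (toy sequences only; no model or batcher needed to follow the arithmetic):

true_ys = [3, 0, 2, 3]   # gold segmentation: S | B E | S  -> 3 words
pred_ys = [3, 3, 3, 3]   # predicted:         S | S | S | S -> 4 words
# yt_wordnum = 3, yp_wordnum = 4; the spans ending at positions 0 and 3 match
# exactly, the one ending at position 2 does not, so cor_num = 2:
# P = 2/4 = 0.5, R = 2/3 ~ 0.667, F ~ 0.571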




reproduction/pos_tag_model/pos_io/pos_reader.py (+89, -0)

@@ -0,0 +1,89 @@
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+
+def cut_long_sentence(sent, max_sample_length=200):
+    sent_no_space = sent.replace(' ', '')
+    cutted_sentence = []
+    if len(sent_no_space) > max_sample_length:
+        parts = sent.strip().split()
+        new_line = ''
+        length = 0
+        for part in parts:
+            length += len(part)
+            new_line += part + ' '
+            if length > max_sample_length:
+                new_line = new_line[:-1]
+                cutted_sentence.append(new_line)
+                length = 0
+                new_line = ''
+        if new_line != '':
+            cutted_sentence.append(new_line[:-1])
+    else:
+        cutted_sentence.append(sent)
+    return cutted_sentence
+
+
+class ConlluPOSReader(object):
+    # The returned DataSet has two fields: words (list of list, the inner list holds characters) and tag (list of str, each str a BMES-style tag).
+    def __init__(self):
+        pass
+
+    def load(self, path):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            for line in f:
+                if line.startswith('\n'):
+                    datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split('\t'))
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        ds = DataSet()
+        for sample in datalist:
+            # print(sample)
+            res = self.get_one(sample)
+            if res is None:
+                continue
+            char_seq = []
+            pos_seq = []
+            for word, tag in zip(res[0], res[1]):
+                if len(word)==1:
+                    char_seq.append(word)
+                    pos_seq.append('S-{}'.format(tag))
+                elif len(word)>1:
+                    pos_seq.append('B-{}'.format(tag))
+                    for _ in range(len(word)-2):
+                        pos_seq.append('M-{}'.format(tag))
+                    pos_seq.append('E-{}'.format(tag))
+                    char_seq.extend(list(word))
+                else:
+                    raise ValueError("Zero length of word detected.")
+
+            ds.append(Instance(words=char_seq,
+                               tag=pos_seq))
+
+        return ds
+
+    def get_one(self, sample):
+        if len(sample)==0:
+            return None
+        text = []
+        pos_tags = []
+        for w in sample:
+            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
+            if t3 == '_':
+                return None
+            text.append(t1)
+            pos_tags.append(t2)
+        return text, pos_tags
+
+if __name__ == '__main__':
+    reader = ConlluPOSReader()
+    d = reader.load('/home/hyan/train.conllx')
+    print('reader')
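What the loader produces for one sentence, assuming CoNLL-X rows whose second and fourth columns hold the word and its POS tag: multi-character words are spelled out per character with B-/M-/E- prefixes and single characters get S-, so the two fields stay aligned one tag per character. A restatement of the inner loop on made-up data:

words, tags = ['中国', '是'], ['NN', 'VC']   # res[0], res[1] from get_one()
char_seq, pos_seq = [], []
for word, tag in zip(words, tags):
    if len(word) == 1:
        char_seq.append(word)
        pos_seq.append('S-{}'.format(tag))
    else:
        pos_seq.append('B-{}'.format(tag))
        pos_seq.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
        pos_seq.append('E-{}'.format(tag))
        char_seq.extend(list(word))
print(char_seq)  # ['中', '国', '是']
print(pos_seq)   # ['B-NN', 'E-NN', 'S-VC']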

reproduction/pos_tag_model/pos_tag.cfg (+8, -6)

@@ -1,16 +1,18 @@
 [train]
-epochs = 300
+epochs = 6
 batch_size = 32
 pickle_path = "./save/"
-validate = false
+validate = true
 save_best_dev = true
 model_saved_path = "./save/"
+valid_step = 250
+eval_sort_key = 'accuracy'
 
 [model]
-rnn_hidden_units = 100
-word_emb_dim = 100
+rnn_hidden_units = 300
+word_emb_dim = 300
 dropout = 0.5
 use_crf = true
 use_cuda = true
+print_every_step = 10
 
 [test]
@@ -34,4 +36,4 @@ pickle_path = "./save/"
 use_crf = true
 use_cuda = true
 rnn_hidden_units = 100
-word_emb_dim = 100
+word_emb_dim = 100
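The retuned [train] and [model] sections (6 epochs, validation every 250 steps, 300-dimensional embeddings and hidden units) parse with a stock INI reader, which is enough to sanity-check the new values. fastNLP ships its own config loader; plain configparser is used here purely as an illustration:

import configparser

cfg = configparser.ConfigParser()
cfg.read('reproduction/pos_tag_model/pos_tag.cfg')
print(cfg.getint('train', 'epochs'))            # 6
print(cfg.getboolean('train', 'validate'))      # True
print(cfg.getint('model', 'rnn_hidden_units'))  # 300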

reproduction/pos_tag_model/process/pos_processor.py (+1, -1)

@@ -78,7 +78,7 @@ class PosOutputStrProcessor(Processor):
             word_pos_list = []
             for word, pos in zip(word_list, pos_list):
                 word_pos_list.append(word + self.sep + pos)
             # TODO: this should be made configurable
             ins['word_pos_output'] = ' '.join(word_pos_list)
 
         return dataset

