Browse Source

新增端到端pos处理到parser的过度代码

tags/v0.2.0
yh yunfan 5 years ago
parent
commit
d5afffee73
1 changed files with 107 additions and 0 deletions
  1. +107
    -0
      reproduction/pos_tag_model/process/pos_processor.py

+ 107
- 0
reproduction/pos_tag_model/process/pos_processor.py View File

@@ -0,0 +1,107 @@

from collections import Counter

from fastNLP.api.processor import Processor
from fastNLP.core.dataset import DataSet

class CombineWordAndPosProcessor(Processor):
def __init__(self, word_field_name, pos_field_name):
super(CombineWordAndPosProcessor, self).__init__(None, None)

self.word_field_name = word_field_name
self.pos_field_name = pos_field_name

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))

for ins in dataset:
chars = ins[self.word_field_name]
bmes_pos = ins[self.pos_field_name]
word_list = []
pos_list = []
pos_stack_cnt = Counter()
char_stack = []
for char, p in zip(chars, bmes_pos):
parts = p.split('-')
pre = parts[0]
post = parts[1]
if pre.lower() == 's':
if len(pos_stack_cnt) != 0:
pos = pos_stack_cnt.most_common(1)[0][0]
pos_list.append(pos)
word_list.append(''.join(char_stack))
pos_list.append(post)
word_list.append(char)
char_stack.clear()
pos_stack_cnt.clear()
elif pre.lower() == 'e':
pos_stack_cnt.update([post])
char_stack.append(char)
pos = pos_stack_cnt.most_common(1)[0][0]
pos_list.append(pos)
word_list.append(''.join(char_stack))
char_stack.clear()
pos_stack_cnt.clear()
elif pre.lower() == 'b':
if len(pos_stack_cnt) != 0:
pos = pos_stack_cnt.most_common(1)[0][0]
pos_list.append(pos)
word_list.append(''.join(char_stack))
char_stack.clear()
pos_stack_cnt.clear()
char_stack.append(char)
pos_stack_cnt.update([post])
else:
char_stack.append(char)
pos_stack_cnt.update([post])

ins['word_list'] = word_list
ins['pos_list'] = pos_list

return dataset

if __name__ == '__main__':
chars = ['迈', '向', '充', '满', '希', '望', '的', '新', '世', '纪', '—', '—', '一', '九', '九', '八', '年', '新', '年', '讲', '话', '(', '附', '图', '片', '1', '张', ')']
bmes_pos = ['B-v', 'E-v', 'B-v', 'E-v', 'B-n', 'E-n', 'S-u', 'S-a', 'B-n', 'E-n', 'B-w', 'E-w', 'B-t', 'M-t', 'M-t', 'M-t', 'E-t', 'B-t', 'E-t', 'B-n', 'E-n', 'S-w', 'S-v', 'B-n', 'E-n', 'S-m', 'S-q', 'S-w']


word_list = []
pos_list = []
pos_stack_cnt = Counter()
char_stack = []
for char, p in zip(''.join(chars), bmes_pos):
parts = p.split('-')
pre = parts[0]
post = parts[1]
if pre.lower() == 's':
if len(pos_stack_cnt) != 0:
pos = pos_stack_cnt.most_common(1)[0][0]
pos_list.append(pos)
word_list.append(''.join(char_stack))
pos_list.append(post)
word_list.append(char)
char_stack.clear()
pos_stack_cnt.clear()
elif pre.lower() == 'e':
pos_stack_cnt.update([post])
char_stack.append(char)
pos = pos_stack_cnt.most_common(1)[0][0]
pos_list.append(pos)
word_list.append(''.join(char_stack))
char_stack.clear()
pos_stack_cnt.clear()
elif pre.lower() == 'b':
if len(pos_stack_cnt) != 0:
pos = pos_stack_cnt.most_common(1)[0][0]
pos_list.append(pos)
word_list.append(''.join(char_stack))
char_stack.clear()
pos_stack_cnt.clear()
char_stack.append(char)
pos_stack_cnt.update([post])
else:
char_stack.append(char)
pos_stack_cnt.update([post])

print(word_list)
print(pos_list)

Loading…
Cancel
Save