
add parser api

tags/v0.2.0
yunfan · 6 years ago
commit b6a0d33cb1
8 changed files with 300 additions and 52 deletions
  1. fastNLP/api/api.py (+74, -1)
  2. fastNLP/api/parser.py (+0, -37)
  3. fastNLP/api/processor.py (+14, -1)
  4. fastNLP/core/dataset.py (+1, -1)
  5. fastNLP/loader/embed_loader.py (+7, -7)
  6. reproduction/Biaffine_parser/infer.py (+10, -5)
  7. reproduction/Biaffine_parser/run_test.py (+116, -0)
  8. reproduction/Biaffine_parser/util.py (+78, -0)

fastNLP/api/api.py (+74, -1)

@@ -8,6 +8,8 @@ from fastNLP.api.model_zoo import load_url
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
@@ -179,6 +181,72 @@ class CWS(API):

        return f1, pre, rec


class Parser(API):
    def __init__(self, model_path=None, device='cpu'):
        super(Parser, self).__init__()
        if model_path is None:
            model_path = model_urls['parser']

        self.load(model_path, device)

    def predict(self, content):
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load the model first.")

        sentence_list = []
        # 1. Check the type of content
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. Build the dataset
        dataset = DataSet()
        dataset.add_field('words', sentence_list)
        # dataset.add_field('tag', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)
        for ins in dataset:
            ins['heads'] = ins['heads'].tolist()

        return dataset['heads'], dataset['labels']

    def test(self, filepath):
        data = ConllxDataLoader().load(filepath)
        ds = DataSet()
        for ins1, ins2 in zip(add_seg_tag(data), data):
            ds.append(Instance(words=ins1[0], tag=ins1[1],
                               gold_words=ins2[0], gold_pos=ins2[1],
                               gold_heads=ins2[2], gold_head_tags=ins2[3]))

        pp = self.pipeline
        for p in pp:
            if p.field_name == 'word_list':
                p.field_name = 'gold_words'
            elif p.field_name == 'pos_list':
                p.field_name = 'gold_pos'
        pp(ds)
        head_cor, label_cor, total = 0, 0, 0
        for ins in ds:
            head_gold = ins['gold_heads']
            head_pred = ins['heads']
            length = len(head_gold)
            total += length
            for i in range(length):
                head_cor += 1 if head_pred[i] == head_gold[i] else 0
        uas = head_cor / total
        print('uas:{:.2f}'.format(uas))

        for p in pp:
            if p.field_name == 'gold_words':
                p.field_name = 'word_list'
            elif p.field_name == 'gold_pos':
                p.field_name = 'pos_list'

        return uas

if __name__ == "__main__":
    # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
    pos = POS(device='cpu')
@@ -195,4 +263,9 @@ if __name__ == "__main__":
         '那么这款无人机到底有多厉害?']
    print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
    cws.predict(s)

    parser = Parser(device='cuda:0')
    print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
    print(parser.predict(s))
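
For reference, a minimal sketch of how the new Parser API is meant to be called from user code; the 'parser' key in model_urls, the raw-sentence input, and the (heads, labels) return pair follow the class above, while the example sentence is just a placeholder:

from fastNLP.api.api import Parser

# Load the released parser model onto the CPU; with no model_path given,
# the constructor above falls back to model_urls['parser'].
parser = Parser(device='cpu')

# predict() accepts a single raw sentence or a list of sentences and returns
# two parallel lists: predicted head indices and dependency labels.
heads, labels = parser.predict(['那么这款无人机到底有多厉害?'])
print(heads[0], labels[0])

# test() evaluates UAS against a CoNLL-X file and returns it as a float.
# uas = parser.test('path/to/test.conll')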

fastNLP/api/parser.py (+0, -37)

@@ -1,37 +0,0 @@
from fastNLP.api.api import API
from fastNLP.core.dataset import DataSet
from fastNLP.core.predictor import Predictor
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import *
from fastNLP.models.biaffine_parser import BiaffineParser

from fastNLP.core.instance import Instance

import torch


class DependencyParser(API):
    def __init__(self):
        super(DependencyParser, self).__init__()

    def predict(self, data):
        if self.pipeline is None:
            self.pipeline = torch.load('xxx')

        dataset = DataSet()
        for sent, pos_seq in data:
            dataset.append(Instance(sentence=sent, sent_pos=pos_seq))
        dataset = self.pipeline.process(dataset)

        return dataset['heads'], dataset['labels']


if __name__ == '__main__':
    data = [
        (['我', '是', '谁'], ['NR', 'VV', 'NR']),
        (['自古', '英雄', '识', '英雄'], ['AD', 'NN', 'VV', 'NN']),
    ]
    parser = DependencyParser()
    with open('/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe/pipeline.pkl', 'rb') as f:
        parser.pipeline = torch.load(f)
    output = parser.predict(data)
    print(output)

fastNLP/api/processor.py (+14, -1)

@@ -198,12 +198,12 @@ class ModelProcessor(Processor):
        :param batch_size:
        """
        super(ModelProcessor, self).__init__(None, None)

        self.batch_size = batch_size
        self.seq_len_field_name = seq_len_field_name
        self.model = model

    def process(self, dataset):
        self.model.eval()
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)

@@ -261,3 +261,16 @@ class SetTensorProcessor(Processor):
        set_dict.update(self.field_dict)
        dataset.set_need_tensor(**set_dict)
        return dataset


class SetIsTargetProcessor(Processor):
    def __init__(self, field_dict, default=False):
        super(SetIsTargetProcessor, self).__init__(None, None)
        self.field_dict = field_dict
        self.default = default

    def process(self, dataset):
        set_dict = {name: self.default for name in dataset.get_fields().keys()}
        set_dict.update(self.field_dict)
        dataset.set_is_target(**set_dict)
        return dataset
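
A hedged sketch of how the new SetIsTargetProcessor is intended to sit inside a Pipeline; the field names 'heads' and 'labels' are illustrative, not prescribed by this commit:

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import SetIsTargetProcessor

pipe = Pipeline()
# ... indexing and model processors would be added here ...
# Mark only 'heads' and 'labels' as target fields; every other field in the
# DataSet falls back to the default value (False).
pipe.add_processor(SetIsTargetProcessor({'heads': True, 'labels': True}, default=False))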

fastNLP/core/dataset.py (+1, -1)

@@ -43,7 +43,7 @@ class DataSet(object):
            self.dataset[name][self.idx] = val

        def __repr__(self):
            return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset])
            return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name in self.dataset.get_fields().keys()])

    def __init__(self, instance=None):
        self.field_arrays = {}
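
The effect of the __repr__ change, roughly: printing an instance view of a DataSet now yields one "field: value" line per field instead of a single space-joined line of values. A hedged illustration, with made-up field names and values:

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet()
ds.append(Instance(words=['我', '是', '谁'], pos=['NR', 'VV', 'NR']))
print(ds[0])
# Expected output, one field per line:
# words: ['我', '是', '谁']
# pos: ['NR', 'VV', 'NR']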


fastNLP/loader/embed_loader.py (+7, -7)

@@ -30,7 +30,7 @@ class EmbedLoader(BaseLoader):
        with open(emb_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
                if len(line) > 0:
                if len(line) > 2:
                    emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
        return emb

@@ -61,10 +61,10 @@
        TODO: fragile code
        """
        # If the embedding pickle exists, load it and return.
        if os.path.exists(emb_pkl):
            with open(emb_pkl, "rb") as f:
                embedding_tensor, vocab = _pickle.load(f)
            return embedding_tensor, vocab
        # if os.path.exists(emb_pkl):
        #     with open(emb_pkl, "rb") as f:
        #         embedding_tensor, vocab = _pickle.load(f)
        #     return embedding_tensor, vocab
        # Otherwise, load the pre-trained embedding.
        pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
        if vocab is None:
@@ -80,6 +80,6 @@ class EmbedLoader(BaseLoader):
                embedding_tensor[vocab[w]] = v

        # save and return the result
        with open(emb_pkl, "wb") as f:
            _pickle.dump((embedding_tensor, vocab), f)
        # with open(emb_pkl, "wb") as f:
        #     _pickle.dump((embedding_tensor, vocab), f)
        return embedding_tensor, vocab
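
The tightened check, len(line) > 2 instead of len(line) > 0, presumably skips lines that carry no vector, such as a word2vec-style header line ("400000 300") or stray short lines. A standalone sketch of the same parsing logic, assuming a whitespace-separated "token v1 v2 ..." file format:

import torch

def load_text_embeddings(path):
    """Parse one "token v1 v2 ..." entry per line; lines with two or fewer
    fields (e.g. a header like "400000 300") are ignored, mirroring the
    len(line) > 2 check above."""
    emb = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = [w for w in line.strip().split(' ') if len(w) > 0]
            if len(parts) > 2:
                emb[parts[0]] = torch.Tensor([float(x) for x in parts[1:]])
    return emb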

reproduction/Biaffine_parser/infer.py (+10, -5)

@@ -24,6 +24,7 @@ def _load_all(src):
    word_v = _load(src+'/word_v.pkl')
    pos_v = _load(src+'/pos_v.pkl')
    tag_v = _load(src+'/tag_v.pkl')
    pos_pp = torch.load(src+'/pos_pp.pkl')['pipeline']

    model_args = ConfigSection()
    ConfigLoader.load_config('cfg.cfg', {'model': model_args})
@@ -38,6 +39,7 @@ def _load_all(src):
        'pos_v': pos_v,
        'tag_v': tag_v,
        'model': model,
        'pos_pp': pos_pp,
    }

def build(load_path, save_path):
@@ -47,19 +49,22 @@ def build(load_path, save_path):
    word_vocab = _dict['word_v']
    pos_vocab = _dict['pos_v']
    tag_vocab = _dict['tag_v']
    pos_pp = _dict['pos_pp']
    model = _dict['model']
    print('load model from {}'.format(load_path))
    word_seq = 'raw_word_seq'
    pos_seq = 'raw_pos_seq'

    # build pipeline
    pipe = Pipeline()
    pipe.add_processor(Num2TagProcessor(NUM, 'sentence', word_seq))
    # input
    pipe = pos_pp
    pipe.pipeline.pop(-1)
    pipe.add_processor(Num2TagProcessor(NUM, 'word_list', word_seq))
    pipe.add_processor(PreAppendProcessor(BOS, word_seq))
    pipe.add_processor(PreAppendProcessor(BOS, 'sent_pos', pos_seq))
    pipe.add_processor(PreAppendProcessor(BOS, 'pos_list', pos_seq))
    pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq'))
    pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq'))
    pipe.add_processor(SeqLenProcessor(word_seq, 'word_seq_origin_len'))
    pipe.add_processor(SeqLenProcessor('word_seq', 'word_seq_origin_len'))
    pipe.add_processor(SetTensorProcessor({'word_seq': True, 'pos_seq': True, 'word_seq_origin_len': True}, default=False))
    pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len'))
    pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads'))
@@ -68,7 +73,7 @@ def build(load_path, save_path):
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    with open(save_path+'/pipeline.pkl', 'wb') as f:
        torch.save(pipe, f)
        torch.save({'pipeline': pipe}, f)
    print('save pipeline in {}'.format(save_path))
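
The save format changes from a bare Pipeline object to a dict keyed by 'pipeline', which is what _load_all and run_test.py now read back. A small sketch of the round trip; the path is a placeholder:

import torch
from fastNLP.api.pipeline import Pipeline

pipe = Pipeline()
# ... add_processor(...) calls as in build() above ...

# Save: wrap the pipeline in a dict, as infer.py now does.
with open('pipe/pipeline.pkl', 'wb') as f:
    torch.save({'pipeline': pipe}, f)

# Load: consumers unpack the 'pipeline' key.
pipe = torch.load('pipe/pipeline.pkl')['pipeline']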




reproduction/Biaffine_parser/run_test.py (+116, -0)

@@ -0,0 +1,116 @@
import sys
import os

sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])

import torch
import argparse
import numpy as np

from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

parser = argparse.ArgumentParser()
parser.add_argument('--pipe', type=str, default='')
parser.add_argument('--gold_data', type=str, default='')
parser.add_argument('--new_data', type=str)
args = parser.parse_args()

pipe = torch.load(args.pipe)['pipeline']
for p in pipe:
    if p.field_name == 'word_list':
        print(p.field_name)
        p.field_name = 'gold_words'
    elif p.field_name == 'pos_list':
        print(p.field_name)
        p.field_name = 'gold_pos'


data = ConllxDataLoader().load(args.gold_data)
ds = DataSet()
for ins1, ins2 in zip(add_seg_tag(data), data):
    ds.append(Instance(words=ins1[0], tag=ins1[1],
                       gold_words=ins2[0], gold_pos=ins2[1],
                       gold_heads=ins2[2], gold_head_tags=ins2[3]))

ds = pipe(ds)

seg_threshold = 0.
pos_threshold = 0.
parse_threshold = 0.74


def get_heads(ins, head_f, word_f):
    head_pred = []
    for i, idx in enumerate(ins[head_f]):
        j = idx - 1 if idx != 0 else i
        head_pred.append(ins[word_f][j])
    return head_pred

def evaluate(ins):
    seg_count = sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j])
    pos_count = sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j])
    head_count = sum([1 for i, j in zip(ins['heads'], ins['gold_heads']) if i == j])
    total = len(ins['gold_words'])
    return seg_count / total, pos_count / total, head_count / total

def is_ok(x):
    seg, pos, head = x[1]
    return seg > seg_threshold and pos > pos_threshold and head > parse_threshold

res_list = []

for i, ins in enumerate(ds):
    res_list.append((i, evaluate(ins)))

res_list = list(filter(is_ok, res_list))
print('{} {}'.format(len(ds), len(res_list)))

seg_cor, pos_cor, head_cor, label_cor, total = 0, 0, 0, 0, 0
for i, _ in res_list:
    ins = ds[i]
    # print(i)
    # print('gold_words:\t', ins['gold_words'])
    # print('predict_words:\t', ins['word_list'])
    # print('gold_tag:\t', ins['gold_pos'])
    # print('predict_tag:\t', ins['pos_list'])
    # print('gold_heads:\t', ins['gold_heads'])
    # print('predict_heads:\t', ins['heads'].tolist())
    # print('gold_head_tags:\t', ins['gold_head_tags'])
    # print('predict_labels:\t', ins['labels'])
    # print()

    head_pred = ins['heads']
    head_gold = ins['gold_heads']
    label_pred = ins['labels']
    label_gold = ins['gold_head_tags']
    total += len(head_gold)
    seg_cor += sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j])
    pos_cor += sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j])
    length = len(head_gold)
    for i in range(length):
        head_cor += 1 if head_pred[i] == head_gold[i] else 0
        label_cor += 1 if head_pred[i] == head_gold[i] and label_gold[i] == label_pred[i] else 0


print('SEG: {}, POS: {}, UAS: {}, LAS: {}'.format(seg_cor/total, pos_cor/total, head_cor/total, label_cor/total))

colln_path = args.gold_data
new_colln_path = args.new_data

index_list = [x[0] for x in res_list]

with open(colln_path, 'r', encoding='utf-8') as f1, \
        open(new_colln_path, 'w', encoding='utf-8') as f2:
    for idx, ins in enumerate(ds):
        if idx in index_list:
            length = len(ins['gold_words'])
            pad = ['_' for _ in range(length)]
            for x in zip(
                map(str, range(1, length+1)), ins['gold_words'], ins['gold_words'], ins['gold_pos'],
                pad, pad, map(str, ins['gold_heads']), ins['gold_head_tags']):
                new_lines = '\t'.join(x)
                f2.write(new_lines)
                f2.write('\n')
            f2.write('\n')
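
For readers unfamiliar with the metrics printed above: UAS counts tokens whose predicted head is correct, and LAS additionally requires the dependency label to match. A hedged helper equivalent to the inline loop, written per sentence:

def uas_las(head_pred, head_gold, label_pred, label_gold):
    """Return (UAS, LAS) for one sentence as fractions of its tokens."""
    assert len(head_pred) == len(head_gold)
    head_hits = sum(1 for hp, hg in zip(head_pred, head_gold) if hp == hg)
    label_hits = sum(1 for hp, hg, lp, lg in zip(head_pred, head_gold, label_pred, label_gold)
                     if hp == hg and lp == lg)
    n = len(head_gold)
    return head_hits / n, label_hits / n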

reproduction/Biaffine_parser/util.py (+78, -0)

@@ -0,0 +1,78 @@
class ConllxDataLoader(object):
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        data = [self.get_one(sample) for sample in datalist]
        return list(filter(lambda x: x is not None, data))

    def get_one(self, sample):
        sample = list(map(list, zip(*sample)))
        if len(sample) == 0:
            return None
        for w in sample[7]:
            if w == '_':
                print('Error Sample {}'.format(sample))
                return None
        # return word_seq, pos_seq, head_seq, head_tag_seq
        return sample[1], sample[3], list(map(int, sample[6])), sample[7]


class MyDataloader:
    def load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
            data = self.parse(lines)
        return data

    def parse(self, lines):
        """
        Parse whitespace-separated CoNLL-style lines into one entry per sentence:
        [word], [pos], [head_index], [head_tag]
        """
        sample = []
        data = []
        for i, line in enumerate(lines):
            line = line.strip()
            if len(line) == 0 or i + 1 == len(lines):
                data.append(list(map(list, zip(*sample))))
                sample = []
            else:
                sample.append(line.split())
        if len(sample) > 0:
            data.append(list(map(list, zip(*sample))))
        return data


def add_seg_tag(data):
    """
    Convert word-level samples into character-level samples whose tags combine
    a BMES segmentation prefix with the POS tag of the original word.

    :param data: list of ([word], [pos], [heads], [head_tags])
    :return: list of ([char], [seg_pos_tag]), e.g. tags like 'S-NN', 'B-VV', 'M-VV', 'E-VV'
    """

    _processed = []
    for word_list, pos_list, _, _ in data:
        new_sample = []
        for word, pos in zip(word_list, pos_list):
            if len(word) == 1:
                new_sample.append((word, 'S-' + pos))
            else:
                new_sample.append((word[0], 'B-' + pos))
                for c in word[1:-1]:
                    new_sample.append((c, 'M-' + pos))
                new_sample.append((word[-1], 'E-' + pos))
        _processed.append(list(map(list, zip(*new_sample))))
    return _processed
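
A quick illustration of what add_seg_tag produces for one sentence; the words, POS tags, and heads below are invented for the example:

from reproduction.Biaffine_parser.util import add_seg_tag

data = [(['中国', '进出口', '银行'], ['NR', 'NN', 'NN'], [3, 3, 0], ['nn', 'nn', 'root'])]
chars, tags = add_seg_tag(data)[0]
print(chars)  # ['中', '国', '进', '出', '口', '银', '行']
print(tags)   # ['B-NR', 'E-NR', 'B-NN', 'M-NN', 'E-NN', 'B-NN', 'E-NN']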
