diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
index 35590d9c..972d3271 100644
--- a/fastNLP/api/api.py
+++ b/fastNLP/api/api.py
@@ -8,6 +8,8 @@ from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader
 from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader
+from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
+from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.batch import Batch
 from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
@@ -179,6 +181,72 @@ class CWS(API):
 
         return f1, pre, rec
 
+
+class Parser(API):
+    def __init__(self, model_path=None, device='cpu'):
+        super(Parser, self).__init__()
+        if model_path is None:
+            model_path = model_urls['parser']
+
+        self.load(model_path, device)
+
+    def predict(self, content):
+        if not hasattr(self, 'pipeline'):
+            raise ValueError("You have to load model first.")
+
+        sentence_list = []
+        # 1. check the input type (single sentence or list of sentences)
+        if isinstance(content, str):
+            sentence_list.append(content)
+        elif isinstance(content, list):
+            sentence_list = content
+
+        # 2. build the DataSet
+        dataset = DataSet()
+        dataset.add_field('words', sentence_list)
+        # dataset.add_field('tag', sentence_list)
+
+        # 3. run the pipeline
+        self.pipeline(dataset)
+        for ins in dataset:
+            ins['heads'] = ins['heads'].tolist()
+
+        return dataset['heads'], dataset['labels']
+
+    def test(self, filepath):
+        data = ConllxDataLoader().load(filepath)
+        ds = DataSet()
+        for ins1, ins2 in zip(add_seg_tag(data), data):
+            ds.append(Instance(words=ins1[0], tag=ins1[1],
+                               gold_words=ins2[0], gold_pos=ins2[1],
+                               gold_heads=ins2[2], gold_head_tags=ins2[3]))
+
+        pp = self.pipeline
+        for p in pp:
+            if p.field_name == 'word_list':
+                p.field_name = 'gold_words'
+            elif p.field_name == 'pos_list':
+                p.field_name = 'gold_pos'
+        pp(ds)
+        head_cor, label_cor, total = 0, 0, 0
+        for ins in ds:
+            head_gold = ins['gold_heads']
+            head_pred = ins['heads']
+            length = len(head_gold)
+            total += length
+            for i in range(length):
+                head_cor += 1 if head_pred[i] == head_gold[i] else 0
+        uas = head_cor / total
+        print('uas:{:.2f}'.format(uas))
+
+        for p in pp:
+            if p.field_name == 'gold_words':
+                p.field_name = 'word_list'
+            elif p.field_name == 'gold_pos':
+                p.field_name = 'pos_list'
+
+        return uas
+
 if __name__ == "__main__":
     # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
     pos = POS(device='cpu')
@@ -195,4 +263,9 @@ if __name__ == "__main__":
          '那么这款无人机到底有多厉害?']
     print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
     cws.predict(s)
-
+    parser = Parser(device='cuda:0')
+    print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
+    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
+         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
+         '那么这款无人机到底有多厉害?']
+    print(parser.predict(s))
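The new Parser class follows the same load-and-pipeline pattern as the POS and CWS APIs above. A minimal usage sketch (the local CoNLL path is a placeholder; only Parser, predict and test come from this diff):

```python
from fastNLP.api.api import Parser

# With model_path=None the constructor falls back to model_urls['parser']
# and loads the saved pipeline onto the requested device.
parser = Parser(model_path=None, device='cpu')

# predict() accepts one raw sentence or a list of sentences and returns two
# parallel lists: predicted head indices and dependency labels per sentence.
heads, labels = parser.predict(['那么这款无人机到底有多厉害?'])
print(heads[0], labels[0])

# test() reads a CoNLL-X file, temporarily redirects the pipeline's
# word_list/pos_list input fields to the gold columns, and returns UAS.
# uas = parser.test('path/to/test.conll')  # placeholder path
```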
diff --git a/fastNLP/api/parser.py b/fastNLP/api/parser.py
deleted file mode 100644
index ec821754..00000000
--- a/fastNLP/api/parser.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from fastNLP.api.api import API
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.predictor import Predictor
-from fastNLP.api.pipeline import Pipeline
-from fastNLP.api.processor import *
-from fastNLP.models.biaffine_parser import BiaffineParser
-
-from fastNLP.core.instance import Instance
-
-import torch
-
-
-class DependencyParser(API):
-    def __init__(self):
-        super(DependencyParser, self).__init__()
-
-    def predict(self, data):
-        if self.pipeline is None:
-            self.pipeline = torch.load('xxx')
-
-        dataset = DataSet()
-        for sent, pos_seq in data:
-            dataset.append(Instance(sentence=sent, sent_pos=pos_seq))
-        dataset = self.pipeline.process(dataset)
-
-        return dataset['heads'], dataset['labels']
-
-if __name__ == '__main__':
-    data = [
-        (['我', '是', '谁'], ['NR', 'VV', 'NR']),
-        (['自古', '英雄', '识', '英雄'], ['AD', 'NN', 'VV', 'NN']),
-    ]
-    parser = DependencyParser()
-    with open('/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe/pipeline.pkl', 'rb') as f:
-        parser.pipeline = torch.load(f)
-    output = parser.predict(data)
-    print(output)
diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py
index df868b8c..999cebac 100644
--- a/fastNLP/api/processor.py
+++ b/fastNLP/api/processor.py
@@ -198,12 +198,12 @@ class ModelProcessor(Processor):
         :param batch_size:
         """
         super(ModelProcessor, self).__init__(None, None)
-
         self.batch_size = batch_size
         self.seq_len_field_name = seq_len_field_name
         self.model = model
 
     def process(self, dataset):
+        self.model.eval()
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
 
         data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
@@ -261,3 +261,16 @@ class SetTensorProcessor(Processor):
         set_dict.update(self.field_dict)
         dataset.set_need_tensor(**set_dict)
         return dataset
+
+
+class SetIsTargetProcessor(Processor):
+    def __init__(self, field_dict, default=False):
+        super(SetIsTargetProcessor, self).__init__(None, None)
+        self.field_dict = field_dict
+        self.default = default
+
+    def process(self, dataset):
+        set_dict = {name: self.default for name in dataset.get_fields().keys()}
+        set_dict.update(self.field_dict)
+        dataset.set_is_target(**set_dict)
+        return dataset
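The new SetIsTargetProcessor mirrors SetTensorProcessor directly above it: it resets every field's is_target flag to `default` and then applies the per-field overrides in `field_dict`. A rough usage sketch (the field names are illustrative, not part of this diff):

```python
from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import SetIsTargetProcessor

ds = DataSet()
ds.add_field('word_seq', [[2, 5, 7], [3, 1]])
ds.add_field('label_seq', [[0, 1, 0], [1, 1]])

# Mark only 'label_seq' as a target field; every other field falls back to False.
proc = SetIsTargetProcessor({'label_seq': True}, default=False)
ds = proc.process(ds)
```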
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 2922699e..3e92e711 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -43,7 +43,7 @@ class DataSet(object):
             self.dataset[name][self.idx] = val
 
         def __repr__(self):
-            return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset])
+            return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name in self.dataset.get_fields().keys()])
 
     def __init__(self, instance=None):
         self.field_arrays = {}
diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py
index 415cb1b9..1b9e0b0b 100644
--- a/fastNLP/loader/embed_loader.py
+++ b/fastNLP/loader/embed_loader.py
@@ -30,7 +30,7 @@ class EmbedLoader(BaseLoader):
         with open(emb_file, 'r', encoding='utf-8') as f:
             for line in f:
                 line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
-                if len(line) > 0:
+                if len(line) > 2:
                     emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
 
@@ -61,10 +61,10 @@ class EmbedLoader(BaseLoader):
         TODO: fragile code
         """
         # If the embedding pickle exists, load it and return.
-        if os.path.exists(emb_pkl):
-            with open(emb_pkl, "rb") as f:
-                embedding_tensor, vocab = _pickle.load(f)
-            return embedding_tensor, vocab
+        # if os.path.exists(emb_pkl):
+        #     with open(emb_pkl, "rb") as f:
+        #         embedding_tensor, vocab = _pickle.load(f)
+        #     return embedding_tensor, vocab
         # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
@@ -80,6 +80,6 @@ class EmbedLoader(BaseLoader):
             embedding_tensor[vocab[w]] = v
 
         # save and return the result
-        with open(emb_pkl, "wb") as f:
-            _pickle.dump((embedding_tensor, vocab), f)
+        # with open(emb_pkl, "wb") as f:
+        #     _pickle.dump((embedding_tensor, vocab), f)
         return embedding_tensor, vocab
diff --git a/reproduction/Biaffine_parser/infer.py b/reproduction/Biaffine_parser/infer.py
index 691c01d0..dc2ccc51 100644
--- a/reproduction/Biaffine_parser/infer.py
+++ b/reproduction/Biaffine_parser/infer.py
@@ -24,6 +24,7 @@ def _load_all(src):
     word_v = _load(src+'/word_v.pkl')
     pos_v = _load(src+'/pos_v.pkl')
     tag_v = _load(src+'/tag_v.pkl')
+    pos_pp = torch.load(src+'/pos_pp.pkl')['pipeline']
 
     model_args = ConfigSection()
     ConfigLoader.load_config('cfg.cfg', {'model': model_args})
@@ -38,6 +39,7 @@ def _load_all(src):
         'pos_v': pos_v,
         'tag_v': tag_v,
         'model': model,
+        'pos_pp':pos_pp,
     }
 
 def build(load_path, save_path):
@@ -47,19 +49,22 @@ def build(load_path, save_path):
     word_vocab = _dict['word_v']
     pos_vocab = _dict['pos_v']
     tag_vocab = _dict['tag_v']
+    pos_pp = _dict['pos_pp']
     model = _dict['model']
     print('load model from {}'.format(load_path))
 
     word_seq = 'raw_word_seq'
     pos_seq = 'raw_pos_seq'
 
     # build pipeline
-    pipe = Pipeline()
-    pipe.add_processor(Num2TagProcessor(NUM, 'sentence', word_seq))
+    # input
+    pipe = pos_pp
+    pipe.pipeline.pop(-1)
+    pipe.add_processor(Num2TagProcessor(NUM, 'word_list', word_seq))
+    pipe.add_processor(PreAppendProcessor(BOS, word_seq))
-    pipe.add_processor(PreAppendProcessor(BOS, 'sent_pos', pos_seq))
+    pipe.add_processor(PreAppendProcessor(BOS, 'pos_list', pos_seq))
     pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq'))
     pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq'))
-    pipe.add_processor(SeqLenProcessor(word_seq, 'word_seq_origin_len'))
+    pipe.add_processor(SeqLenProcessor('word_seq', 'word_seq_origin_len'))
     pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False))
     pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len'))
     pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads'))
@@ -68,7 +73,7 @@ def build(load_path, save_path):
 
     if not os.path.exists(save_path):
         os.makedirs(save_path)
     with open(save_path+'/pipeline.pkl', 'wb') as f:
-        torch.save(pipe, f)
+        torch.save({'pipeline': pipe}, f)
     print('save pipeline in {}'.format(save_path))
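Note that build() now saves the pipeline wrapped in a dict, so any downstream code has to unpack the 'pipeline' key, as _load_all() above and run_test.py below already do. A minimal loading sketch (the path is a placeholder):

```python
import torch

# The file written by build() has the form {'pipeline': pipe}.
saved = torch.load('path/to/pipeline.pkl')  # placeholder path
pipe = saved['pipeline']

# The unpacked Pipeline is callable on a DataSet, as in Parser.predict above:
# ds = pipe(ds)
```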
diff --git a/reproduction/Biaffine_parser/run_test.py b/reproduction/Biaffine_parser/run_test.py
new file mode 100644
index 00000000..6a67f45a
--- /dev/null
+++ b/reproduction/Biaffine_parser/run_test.py
@@ -0,0 +1,116 @@
+import sys
+import os
+
+sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
+
+import torch
+import argparse
+import numpy as np
+
+from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--pipe', type=str, default='')
+parser.add_argument('--gold_data', type=str, default='')
+parser.add_argument('--new_data', type=str)
+args = parser.parse_args()
+
+pipe = torch.load(args.pipe)['pipeline']
+for p in pipe:
+    if p.field_name == 'word_list':
+        print(p.field_name)
+        p.field_name = 'gold_words'
+    elif p.field_name == 'pos_list':
+        print(p.field_name)
+        p.field_name = 'gold_pos'
+
+
+data = ConllxDataLoader().load(args.gold_data)
+ds = DataSet()
+for ins1, ins2 in zip(add_seg_tag(data), data):
+    ds.append(Instance(words=ins1[0], tag=ins1[1],
+                       gold_words=ins2[0], gold_pos=ins2[1],
+                       gold_heads=ins2[2], gold_head_tags=ins2[3]))
+
+ds = pipe(ds)
+
+seg_threshold = 0.
+pos_threshold = 0.
+parse_threshold = 0.74
+
+
+def get_heads(ins, head_f, word_f):
+    head_pred = []
+    for i, idx in enumerate(ins[head_f]):
+        j = idx - 1 if idx != 0 else i
+        head_pred.append(ins[word_f][j])
+    return head_pred
+
+def evaluate(ins):
+    seg_count = sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j])
+    pos_count = sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j])
+    head_count = sum([1 for i, j in zip(ins['heads'], ins['gold_heads']) if i == j])
+    total = len(ins['gold_words'])
+    return seg_count / total, pos_count / total, head_count / total
+
+def is_ok(x):
+    seg, pos, head = x[1]
+    return seg > seg_threshold and pos > pos_threshold and head > parse_threshold
+
+res_list = []
+
+for i, ins in enumerate(ds):
+    res_list.append((i, evaluate(ins)))
+
+res_list = list(filter(is_ok, res_list))
+print('{} {}'.format(len(ds), len(res_list)))
+
+seg_cor, pos_cor, head_cor, label_cor, total = 0,0,0,0,0
+for i, _ in res_list:
+    ins = ds[i]
+    # print(i)
+    # print('gold_words:\t', ins['gold_words'])
+    # print('predict_words:\t', ins['word_list'])
+    # print('gold_tag:\t', ins['gold_pos'])
+    # print('predict_tag:\t', ins['pos_list'])
+    # print('gold_heads:\t', ins['gold_heads'])
+    # print('predict_heads:\t', ins['heads'].tolist())
+    # print('gold_head_tags:\t', ins['gold_head_tags'])
+    # print('predict_labels:\t', ins['labels'])
+    # print()
+
+    head_pred = ins['heads']
+    head_gold = ins['gold_heads']
+    label_pred = ins['labels']
+    label_gold = ins['gold_head_tags']
+    total += len(head_gold)
+    seg_cor += sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j])
+    pos_cor += sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j])
+    length = len(head_gold)
+    for i in range(length):
+        head_cor += 1 if head_pred[i] == head_gold[i] else 0
+        label_cor += 1 if head_pred[i] == head_gold[i] and label_gold[i] == label_pred[i] else 0
+
+
+print('SEG: {}, POS: {}, UAS: {}, LAS: {}'.format(seg_cor/total, pos_cor/total, head_cor/total, label_cor/total))
+
+colln_path = args.gold_data
+new_colln_path = args.new_data
+
+index_list = [x[0] for x in res_list]
+
+with open(colln_path, 'r', encoding='utf-8') as f1, \
+        open(new_colln_path, 'w', encoding='utf-8') as f2:
+    for idx, ins in enumerate(ds):
+        if idx in index_list:
+            length = len(ins['gold_words'])
+            pad = ['_' for _ in range(length)]
+            for x in zip(
+                    map(str, range(1, length+1)), ins['gold_words'], ins['gold_words'], ins['gold_pos'],
+                    pad, pad, map(str, ins['gold_heads']), ins['gold_head_tags']):
+                new_lines = '\t'.join(x)
+                f2.write(new_lines)
+                f2.write('\n')
+            f2.write('\n')
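The SEG/POS/UAS/LAS numbers printed above are plain token-level ratios over the retained sentences. For reference, a compact, dependency-free restatement of the attachment-score part (a sketch, not code from this diff):

```python
def uas_las(pred_heads, gold_heads, pred_labels, gold_labels):
    """Unlabeled/labeled attachment scores over a batch of sentences."""
    total = head_cor = label_cor = 0
    for hp, hg, lp, lg in zip(pred_heads, gold_heads, pred_labels, gold_labels):
        for i in range(len(hg)):
            total += 1
            if hp[i] == hg[i]:
                head_cor += 1
                # LAS additionally requires the dependency label to match.
                if lp[i] == lg[i]:
                    label_cor += 1
    return head_cor / total, label_cor / total

# One 3-token sentence: two heads are correct, one of those also has the right label.
print(uas_las([[2, 0, 2]], [[2, 0, 1]],
              [['nsubj', 'root', 'obj']], [['nsubj', 'punct', 'dobj']]))
# -> (0.666..., 0.333...)
```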
diff --git a/reproduction/Biaffine_parser/util.py b/reproduction/Biaffine_parser/util.py
new file mode 100644
index 00000000..793b1fb2
--- /dev/null
+++ b/reproduction/Biaffine_parser/util.py
@@ -0,0 +1,78 @@
+class ConllxDataLoader(object):
+    def load(self, path):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            for line in f:
+                if line.startswith('\n'):
+                    datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split('\t'))
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        data = [self.get_one(sample) for sample in datalist]
+        return list(filter(lambda x: x is not None, data))
+
+    def get_one(self, sample):
+        sample = list(map(list, zip(*sample)))
+        if len(sample) == 0:
+            return None
+        for w in sample[7]:
+            if w == '_':
+                print('Error Sample {}'.format(sample))
+                return None
+        # return word_seq, pos_seq, head_seq, head_tag_seq
+        return sample[1], sample[3], list(map(int, sample[6])), sample[7]
+
+
+class MyDataloader:
+    def load(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return data
+
+    def parse(self, lines):
+        """
+            [
+                [word], [pos], [head_index], [head_tag]
+            ]
+        """
+        sample = []
+        data = []
+        for i, line in enumerate(lines):
+            line = line.strip()
+            if len(line) == 0 or i + 1 == len(lines):
+                data.append(list(map(list, zip(*sample))))
+                sample = []
+            else:
+                sample.append(line.split())
+        if len(sample) > 0:
+            data.append(list(map(list, zip(*sample))))
+        return data
+
+
+def add_seg_tag(data):
+    """
+
+    :param data: list of ([word], [pos], [heads], [head_tags])
+    :return: list of ([word], [pos])
+    """
+
+    _processed = []
+    for word_list, pos_list, _, _ in data:
+        new_sample = []
+        for word, pos in zip(word_list, pos_list):
+            if len(word) == 1:
+                new_sample.append((word, 'S-' + pos))
+            else:
+                new_sample.append((word[0], 'B-' + pos))
+                for c in word[1:-1]:
+                    new_sample.append((c, 'M-' + pos))
+                new_sample.append((word[-1], 'E-' + pos))
+        _processed.append(list(map(list, zip(*new_sample))))
+    return _processed
\ No newline at end of file
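ConllxDataLoader and add_seg_tag are what Parser.test and run_test.py use to turn a CoNLL-X treebank into character-level, BMES-style tagged input for the joint pipeline. A small illustration of the conversion (the file path is a placeholder and the sample tuple is made up):

```python
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag

# Each loaded sample is (word_seq, pos_seq, head_seq, head_tag_seq),
# taken from CoNLL-X columns 2, 4, 7 and 8.
data = ConllxDataLoader().load('path/to/train.conllx')  # placeholder path

# add_seg_tag flattens words to characters and attaches B/M/E/S-prefixed POS tags.
chars, tags = add_seg_tag([(['中国', '人'], ['NR', 'NN'], [2, 0], ['nn', 'root'])])[0]
print(chars)  # ['中', '国', '人']
print(tags)   # ['B-NR', 'E-NR', 'S-NN']
```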