From 8906155ca2e86f16868d683d27d5caa4234a653a Mon Sep 17 00:00:00 2001
From: yh
Date: Wed, 14 Nov 2018 23:15:19 +0800
Subject: [PATCH] Build an Analyzer for the api
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/api/api.py                                  | 138 ++++------
 .../chinese_word_segment/testcontext.py             |  47 ----
 .../chinese_word_segment/train_context.py           | 245 ------------------
 reproduction/pos_tag_model/testcontext.py           |   0
 reproduction/pos_tag_model/train_pos_tag.py         | 127 ---------
 5 files changed, 51 insertions(+), 506 deletions(-)
 delete mode 100644 reproduction/chinese_word_segment/testcontext.py
 delete mode 100644 reproduction/chinese_word_segment/train_context.py
 delete mode 100644 reproduction/pos_tag_model/testcontext.py
 delete mode 100644 reproduction/pos_tag_model/train_pos_tag.py

diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
index 1ea78bb7..ddb855bb 100644
--- a/fastNLP/api/api.py
+++ b/fastNLP/api/api.py
@@ -34,7 +34,6 @@ class API:
         if os.path.exists(os.path.expanduser(path)):
             _dict = torch.load(path, map_location='cpu')
         else:
-            print(os.path.expanduser(path))
             _dict = load_url(path, map_location='cpu')
         self.pipeline = _dict['pipeline']
         self._dict = _dict
@@ -58,7 +57,7 @@ class POS(API):
 
     def predict(self, content):
         """
-        :param query: list of list of str. Each string is a token(word).
+        :param content: list of list of str. Each string is a token(word).
         :return answer: list of list of str. Each string is a tag.
         """
         if not hasattr(self, 'pipeline'):
@@ -183,99 +182,64 @@ class CWS(API):
         return f1, pre, rec
 
 
-class Parser(API):
-    def __init__(self, model_path=None, device='cpu'):
-        super(Parser, self).__init__()
-        if model_path is None:
-            model_path = model_urls['parser']
+class Analyzer:
+    def __init__(self, seg=True, pos=True, parser=True, device='cpu'):
 
-        self.load(model_path, device)
+        self.seg = seg
+        self.pos = pos
+        self.parser = parser
 
-    def predict(self, content):
-        if not hasattr(self, 'pipeline'):
-            raise ValueError("You have to load model first.")
+        if self.seg:
+            self.cws = CWS(device=device)
+        if self.pos:
+            self.pos = POS(device=device)
+        if parser:
+            self.parser = None
 
-        sentence_list = []
-        # 1. Check the type of the input.
-        if isinstance(content, str):
-            sentence_list.append(content)
-        elif isinstance(content, list):
-            sentence_list = content
-
-        # 2. Build the dataset.
-        dataset = DataSet()
-        dataset.add_field('words', sentence_list)
-        # dataset.add_field('tag', sentence_list)
-
-        # 3. Run the pipeline.
-        self.pipeline(dataset)
-        for ins in dataset:
-            ins['heads'] = ins['heads'].tolist()
-
-        return dataset['heads'], dataset['labels']
+    def predict(self, content):
+        output_dict = {}
+        if self.seg:
+            seg_output = self.cws.predict(content)
+            output_dict['seg'] = seg_output
+        if self.pos:
+            pos_output = self.pos.predict(content)
+            output_dict['pos'] = pos_output
+        if self.parser:
+            parser_output = self.parser.predict(content)
+            output_dict['parser'] = parser_output
+
+        return output_dict
 
     def test(self, filepath):
-        data = ConllxDataLoader().load(filepath)
-        ds = DataSet()
-        for ins1, ins2 in zip(add_seg_tag(data), data):
-            ds.append(Instance(words=ins1[0], tag=ins1[1],
-                               gold_words=ins2[0], gold_pos=ins2[1],
-                               gold_heads=ins2[2], gold_head_tags=ins2[3]))
-
-        pp = self.pipeline
-        for p in pp:
-            if p.field_name == 'word_list':
-                p.field_name = 'gold_words'
-            elif p.field_name == 'pos_list':
-                p.field_name = 'gold_pos'
-        pp(ds)
-        head_cor, label_cor, total = 0, 0, 0
-        for ins in ds:
-            head_gold = ins['gold_heads']
-            head_pred = ins['heads']
-            length = len(head_gold)
-            total += length
-            for i in range(length):
-                head_cor += 1 if head_pred[i] == head_gold[i] else 0
-        uas = head_cor / total
-        print('uas:{:.2f}'.format(uas))
-
-        for p in pp:
-            if p.field_name == 'gold_words':
-                p.field_name = 'word_list'
-            elif p.field_name == 'gold_pos':
-                p.field_name = 'pos_list'
-
-        return uas
+        output_dict = {}
+        if self.seg:
+            seg_output = self.cws.test(filepath)
+            output_dict['seg'] = seg_output
+        if self.pos:
+            pos_output = self.pos.test(filepath)
+            output_dict['pos'] = pos_output
+        if self.parser:
+            parser_output = self.parser.test(filepath)
+            output_dict['parser'] = parser_output
+        return output_dict
 
 
-if __name__ == "__main__":
-    # The following paths are on server 102.
-    """
-    pos_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/pos_crf-5e26d3b0.pkl'
-    pos = POS(model_path=pos_model_path, device='cpu')
-    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
-         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
-         '那么这款无人机到底有多厉害?']
-    #print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
-    print(pos.predict(s))
-    """
-    """
-    cws_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/cws_crf-5a8a3e66.pkl'
-    cws = CWS(model_path=cws_model_path, device='cuda:0')
-    s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
-         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
+if __name__ == "__main__":
+    # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
+    # pos = POS(device='cpu')
+    # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
+    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
+    #      '那么这款无人机到底有多厉害?']
+    # print(pos.test('/Users/yh/Desktop/test_data/small_test.conll'))
+    # print(pos.predict(s))
+
+    # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
+    cws = CWS(device='cpu')
+    s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
+         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
          '那么这款无人机到底有多厉害?']
-    #print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
-    cws.predict(s)
-    """
+    print(cws.test('/Users/yh/Desktop/test_data/small_test.conll'))
+    print(cws.predict(s))
 
-    parser_model_path = "/home/hyan/fastNLP_models/upload-demo/upload/parser-d57cd5fc.pkl"
-    parser = Parser(model_path=parser_model_path, device='cuda:0')
-    # print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
-    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
-         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
-         '那么这款无人机到底有多厉害?']
-    print(parser.predict(s))
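The new Analyzer is a thin wrapper that fans one request out to the individual CWS/POS/Parser APIs and collects their outputs in a single dict. A minimal usage sketch (not part of the patch), assuming the pretrained CWS and POS models behind model_urls can be downloaded; note that this commit leaves the parser component as None, so no 'parser' entry is produced:

    from fastNLP.api.api import Analyzer

    # Segmentation + POS only; the parser slot is still a placeholder in this commit.
    analyzer = Analyzer(seg=True, pos=True, parser=False, device='cpu')
    s = ['那么这款无人机到底有多厉害?']
    output = analyzer.predict(s)
    print(output['seg'])  # word segmentation from CWS.predict
    print(output['pos'])  # POS tags from POS.predict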
diff --git a/reproduction/chinese_word_segment/testcontext.py b/reproduction/chinese_word_segment/testcontext.py
deleted file mode 100644
index 44444001..00000000
--- a/reproduction/chinese_word_segment/testcontext.py
+++ /dev/null
@@ -1,47 +0,0 @@
-
-
-import torch
-from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
-from fastNLP.core.sampler import SequentialSampler
-from fastNLP.core.batch import Batch
-from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
-
-def f1():
-    ds_name = 'pku'
-
-    test_dict = torch.load('models/test_context.pkl')
-
-
-    pp = test_dict['pipeline']
-    model = test_dict['model'].cuda()
-
-    reader = NaiveCWSReader()
-    te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/{}_raw_data/{}_raw_test.txt'.format(ds_name, ds_name,
-                                                                                                     ds_name)
-    te_dataset = reader.load(te_filename)
-    pp(te_dataset)
-
-    batch_size = 64
-    te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
-    pre, rec, f1 = calculate_pre_rec_f1(model, te_batcher)
-    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
-                                                     pre * 100,
-                                                     rec * 100))
-
-
-def f2():
-    from fastNLP.api.api import CWS
-    cws = CWS('models/maml-cws.pkl')
-    datasets = ['msr', 'as', 'pku', 'ctb', 'ncc', 'cityu', 'ckip', 'sxu']
-    for dataset in datasets:
-        print(dataset)
-        with open('/hdd/fudanNLP/CWS/others/benchmark/raw_and_gold/{}_raw.txt'.format(dataset), 'r') as f:
-            lines = f.readlines()
-        results = cws.predict(lines)
-
-        with open('/hdd/fudanNLP/CWS/others/benchmark/fastNLP_output/{}_seg.txt'.format(dataset), 'w', encoding='utf-8') as f:
-            for line in results:
-                f.write(line)
-
-
-f1()
\ No newline at end of file
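The deleted testcontext.py evaluated a pickled pipeline and model by hand. Roughly the same check can now go through the packaged CWS API; a sketch only, with placeholder paths, and assuming CWS.test returns (f1, pre, rec) as the api.py context above suggests:

    from fastNLP.api.api import CWS

    # Model path and test file below are placeholders, not paths from the repository.
    cws = CWS(model_path='models/cws_crf.pkl', device='cpu')
    f1, pre, rec = cws.test('path/to/pku_test.conll')
    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, pre * 100, rec * 100))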
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
deleted file mode 100644
index 186b8720..00000000
--- a/reproduction/chinese_word_segment/train_context.py
+++ /dev/null
@@ -1,245 +0,0 @@
-
-from fastNLP.api.pipeline import Pipeline
-from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
-from fastNLP.api.processor import IndexerProcessor
-from reproduction.chinese_word_segment.process.cws_processor import SpeicalSpanProcessor
-from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
-from reproduction.chinese_word_segment.process.cws_processor import CWSSegAppTagProcessor
-from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
-from reproduction.chinese_word_segment.process.cws_processor import VocabProcessor
-from reproduction.chinese_word_segment.process.cws_processor import SeqLenProcessor
-
-from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter
-from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter
-from reproduction.chinese_word_segment.process.span_converter import TimeConverter
-from reproduction.chinese_word_segment.process.span_converter import MixNumAlphaConverter
-from reproduction.chinese_word_segment.process.span_converter import EmailConverter
-from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
-from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
-
-from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
-
-ds_name = 'pku'
-# tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
-#                                                                                             ds_name)
-# dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
-#                                                                                            ds_name)
-
-tr_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
-                                                                                           ds_name)
-dev_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
-                                                                                          ds_name)
-
-reader = NaiveCWSReader()
-
-tr_dataset = reader.load(tr_filename, cut_long_sent=True)
-dev_dataset = reader.load(dev_filename)
-
-
-# 1. Prepare the processors.
-fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
-
-# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
-# sp_proc.add_span_converter(EmailConverter())
-# sp_proc.add_span_converter(MixNumAlphaConverter())
-# sp_proc.add_span_converter(AlphaSpanConverter())
-# sp_proc.add_span_converter(DigitSpanConverter())
-# sp_proc.add_span_converter(TimeConverter())
-
-
-char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')
-
-tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')
-
-bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
-
-char_vocab_proc = VocabProcessor('chars_list')
-bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
-
-# 2. Apply the processors.
-fs2hs_proc(tr_dataset)
-
-# sp_proc(tr_dataset)
-
-char_proc(tr_dataset)
-tag_proc(tr_dataset)
-bigram_proc(tr_dataset)
-
-char_vocab_proc(tr_dataset)
-bigram_vocab_proc(tr_dataset)
-
-char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'chars',
-                                   delete_old_field=False)
-bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list', 'bigrams',
-                                     delete_old_field=True)
-seq_len_proc = SeqLenProcessor('chars')
-
-char_index_proc(tr_dataset)
-bigram_index_proc(tr_dataset)
-seq_len_proc(tr_dataset)
-
-# 2.1 Process dev_dataset.
-fs2hs_proc(dev_dataset)
-# sp_proc(dev_dataset)
-
-char_proc(dev_dataset)
-tag_proc(dev_dataset)
-bigram_proc(dev_dataset)
-
-char_index_proc(dev_dataset)
-bigram_index_proc(dev_dataset)
-seq_len_proc(dev_dataset)
-
-print("Finish preparing data.")
-print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), bigram_vocab_proc.get_vocab_size()))
-
-
-# 3. The datasets are now ready for training.
-# TODO how is the pretrained embedding handled?
-
-from reproduction.chinese_word_segment.utils import FocalLoss
-from reproduction.chinese_word_segment.utils import seq_lens_to_mask
-from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import BucketSampler
-from fastNLP.core.sampler import SequentialSampler
-
-import torch
-from torch import optim
-import sys
-from tqdm import tqdm
-
-
-tag_size = tag_proc.tag_size
-
-cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
-                            bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
-                            bigram_embed_dim=100, num_bigram_per_char=8,
-                            hidden_size=200, bidirectional=True, embed_drop_p=None,
-                            num_layers=1, tag_size=tag_size)
-cws_model.cuda()
-
-num_epochs = 3
-loss_fn = FocalLoss(class_num=tag_size)
-optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)
-
-
-print_every = 50
-batch_size = 32
-tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
-dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
-num_batch_per_epoch = len(tr_dataset) // batch_size
-best_f1 = 0
-best_epoch = 0
-for num_epoch in range(num_epochs):
-    print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
-    sys.stdout.flush()
-    avg_loss = 0
-    with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
-        pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
-        cws_model.train()
-        for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
-            optimizer.zero_grad()
-
-            pred_dict = cws_model(**batch_x)  # B x L x tag_size
-
-            seq_lens = pred_dict['seq_lens']
-            masks = seq_lens_to_mask(seq_lens).float()
-            tags = batch_y['tags'].long().to(seq_lens.device)
-
-            loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
-                                     tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
-            # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
-
-            avg_loss += loss.item()
-
-            loss.backward()
-            for group in optimizer.param_groups:
-                for param in group['params']:
-                    param.grad.clamp_(-5, 5)
-
-            optimizer.step()
-
-            if batch_idx % print_every == 0:
-                pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
-                avg_loss = 0
-                pbar.update(print_every)
-    tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
-    # Validation set
-    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
-    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
-                                                     pre*100,
-                                                     rec*100))
-    if best_f1