diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index ff3f4260..35590d9c 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -5,6 +5,16 @@ import os from fastNLP.core.dataset import DataSet from fastNLP.api.model_zoo import load_url +from fastNLP.api.processor import ModelProcessor +from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader +from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader +from fastNLP.core.sampler import SequentialSampler +from fastNLP.core.batch import Batch +from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 +from fastNLP.api.pipeline import Pipeline +from fastNLP.core.metrics import SeqLabelEvaluator2 +from fastNLP.core.tester import Tester + model_urls = { } @@ -17,12 +27,17 @@ class API: def predict(self, *args, **kwargs): raise NotImplementedError - def load(self, path): + def load(self, path, device): if os.path.exists(os.path.expanduser(path)): - _dict = torch.load(path) + _dict = torch.load(path, map_location='cpu') else: - _dict = load_url(path) + print(os.path.expanduser(path)) + _dict = load_url(path, map_location='cpu') self.pipeline = _dict['pipeline'] + self._dict = _dict + for processor in self.pipeline.pipeline: + if isinstance(processor, ModelProcessor): + processor.set_model_device(device) class POS(API): @@ -30,12 +45,12 @@ class POS(API): """ - def __init__(self, model_path=None): + def __init__(self, model_path=None, device='cpu'): super(POS, self).__init__() if model_path is None: model_path = model_urls['pos'] - self.load(model_path) + self.load(model_path, device) def predict(self, content): """ @@ -66,14 +81,53 @@ class POS(API): elif isinstance(content, list): return output + def test(self, filepath): + + tag_proc = self._dict['tag_indexer'] + + model = self.pipeline.pipeline[2].model + pipeline = self.pipeline.pipeline[0:2] + pipeline.append(tag_proc) + pp = Pipeline(pipeline) + + reader = ConlluPOSReader() + te_dataset = reader.load(filepath) + + evaluator = SeqLabelEvaluator2('word_seq_origin_len') + end_tagidx_set = set() + tag_proc.vocab.build_vocab() + for key, value in tag_proc.vocab.word2idx.items(): + if key.startswith('E-'): + end_tagidx_set.add(value) + if key.startswith('S-'): + end_tagidx_set.add(value) + evaluator.end_tagidx_set = end_tagidx_set + + default_valid_args = {"batch_size": 64, + "use_cuda": True, "evaluator": evaluator} + + pp(te_dataset) + te_dataset.set_is_target(truth=True) + + tester = Tester(**default_valid_args) + + test_result = tester.test(model, te_dataset) + + f1 = round(test_result['F'] * 100, 2) + pre = round(test_result['P'] * 100, 2) + rec = round(test_result['R'] * 100, 2) + print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) + + return f1, pre, rec + class CWS(API): - def __init__(self, model_path=None): + def __init__(self, model_path=None, device='cpu'): super(CWS, self).__init__() if model_path is None: model_path = model_urls['cws'] - self.load(model_path) + self.load(model_path, device) def predict(self, content): @@ -100,17 +154,45 @@ class CWS(API): elif isinstance(content, list): return output + def test(self, filepath): + + tag_proc = self._dict['tag_indexer'] + cws_model = self.pipeline.pipeline[-2].model + pipeline = self.pipeline.pipeline[:5] + + pipeline.insert(1, tag_proc) + pp = Pipeline(pipeline) + + reader = ConlluCWSReader() + + # te_filename = '/home/hyan/ctb3/test.conllx' + te_dataset = reader.load(filepath) + pp(te_dataset) + + batch_size = 64 + te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) + pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes') + f1 = round(f1 * 100, 2) + pre = round(pre * 100, 2) + rec = round(rec * 100, 2) + print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) + + return f1, pre, rec if __name__ == "__main__": - pos = POS() + # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' + pos = POS(device='cpu') s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', '那么这款无人机到底有多厉害?'] + print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll')) print(pos.predict(s)) - # cws = CWS() - # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , - # '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', - # '那么这款无人机到底有多厉害?'] - # print(cws.predict(s)) + # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' + cws = CWS(device='cuda:0') + s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , + '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', + '那么这款无人机到底有多厉害?'] + print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll')) + cws.predict(s) diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 91935fd1..df868b8c 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -234,6 +234,10 @@ class ModelProcessor(Processor): def set_model(self, model): self.model = model + def set_model_device(self, device): + device = torch.device(device) + self.model.to(device) + class Index2WordProcessor(Processor): def __init__(self, vocab, field_name, new_added_field_name): super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index dfdd397d..0c7456c7 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -53,7 +53,7 @@ class Tester(object): else: # Tester doesn't care about extra arguments pass - print(default_args) + # print(default_args) self.batch_size = default_args["batch_size"] self.pickle_path = default_args["pickle_path"] @@ -84,8 +84,8 @@ class Tester(object): for k, v in batch_y.items(): truths[k].append(v) eval_results = self.evaluate(**output, **truths) - print("[tester] {}".format(self.print_eval_results(eval_results))) - logger.info("[tester] {}".format(self.print_eval_results(eval_results))) + # print("[tester] {}".format(self.print_eval_results(eval_results))) + # logger.info("[tester] {}".format(self.print_eval_results(eval_results))) self.mode(network, is_test=False) self.metrics = eval_results return eval_results