From e9c93ad0779a25903bc98d0079d88560d9dcb03c Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Tue, 8 Jan 2019 16:09:36 +0800 Subject: [PATCH] * refactor test API for POS tagging * add default sampler for Batch * fix bug in metrics.py: slice must be integer --- fastNLP/api/api.py | 80 ++++++++++++++++++----------------------- fastNLP/core/batch.py | 4 ++- fastNLP/core/metrics.py | 4 +-- 3 files changed, 40 insertions(+), 48 deletions(-) diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 2e6cc247..47c29214 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -10,13 +10,15 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.model_zoo import load_url from fastNLP.api.processor import ModelProcessor from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader -from reproduction.pos_tag_model.pos_reader import ConllPOSReader +from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag from fastNLP.core.instance import Instance from fastNLP.core.sampler import SequentialSampler from fastNLP.core.batch import Batch from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 from fastNLP.api.pipeline import Pipeline +from fastNLP.core.metrics import SpanFPreRecMetric +from fastNLP.api.processor import IndexerProcessor # TODO add pretrain urls @@ -65,7 +67,7 @@ class POS(API): :param content: list of list of str. Each string is a token(word). :return answer: list of list of str. Each string is a tag. """ - if not hasattr(self, 'pipeline'): + if not hasattr(self, "pipeline"): raise ValueError("You have to load model first.") sentence_list = [] @@ -104,47 +106,35 @@ class POS(API): elif isinstance(content, list): return output - def test(self, filepath): - - tag_proc = self._dict['tag_indexer'] - - model = self.pipeline.pipeline[2].model - pipeline = self.pipeline.pipeline[0:2] - pipeline.append(tag_proc) - pp = Pipeline(pipeline) - - reader = ConllPOSReader() - te_dataset = reader.load(filepath) - - """ - evaluator = SeqLabelEvaluator2('word_seq_origin_len') - end_tagidx_set = set() - tag_proc.vocab.build_vocab() - for key, value in tag_proc.vocab.word2idx.items(): - if key.startswith('E-'): - end_tagidx_set.add(value) - if key.startswith('S-'): - end_tagidx_set.add(value) - evaluator.end_tagidx_set = end_tagidx_set - - pp(te_dataset) - te_dataset.set_target(truth=True) - - default_valid_args = {"batch_size": 64, - "use_cuda": True, "evaluator": evaluator, - "model": model, "data": te_dataset} - - tester = Tester(**default_valid_args) - - test_result = tester.test() - - f1 = round(test_result['F'] * 100, 2) - pre = round(test_result['P'] * 100, 2) - rec = round(test_result['R'] * 100, 2) - # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) - - return f1, pre, rec - """ + def test(self, file_path): + test_data = ZhConllPOSReader().load(file_path) + + tag_vocab = self._dict["tag_vocab"] + pipeline = self._dict["pipeline"] + index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) + pipeline.pipeline = [index_tag] + pipeline.pipeline + + pipeline(test_data) + test_data.set_target("truth") + prediction = test_data.field_arrays["predict"].content + truth = test_data.field_arrays["truth"].content + seq_len = test_data.field_arrays["word_seq_origin_len"].content + + # padding by hand + max_length = max([len(seq) for seq in prediction]) + for idx in range(len(prediction)): + prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) + truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) + evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", + seq_lens="word_seq_origin_len") + evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, + {"truth": torch.Tensor(truth)}) + test_result = evaluator.get_metric() + f1 = round(test_result['f'] * 100, 2) + pre = round(test_result['pre'] * 100, 2) + rec = round(test_result['rec'] * 100, 2) + + return {"F1": f1, "precision": pre, "recall": rec} class CWS(API): @@ -316,8 +306,8 @@ if __name__ == "__main__": s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', '那么这款无人机到底有多厉害?'] - # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) - print(pos.predict(s)) + print(pos.test("/home/zyfeng/data/sample.conllx")) + # print(pos.predict(s)) # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' # cws = CWS(device='cpu') diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 7b3b6d11..05160312 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -1,6 +1,8 @@ import numpy as np import torch +from fastNLP.core.sampler import RandomSampler + class Batch(object): """Batch is an iterable object which iterates over mini-batches. @@ -17,7 +19,7 @@ class Batch(object): """ - def __init__(self, dataset, batch_size, sampler, as_numpy=False): + def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): self.dataset = dataset self.batch_size = batch_size self.sampler = sampler diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 39d5bcf3..dfb20480 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -451,8 +451,8 @@ class SpanFPreRecMetric(MetricBase): batch_size = pred.size(0) for i in range(batch_size): - pred_tags = pred[i, :seq_lens[i]].tolist() - gold_tags = target[i, :seq_lens[i]].tolist() + pred_tags = pred[i, :int(seq_lens[i])].tolist() + gold_tags = target[i, :int(seq_lens[i])].tolist() pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]