|
|
@@ -10,13 +10,15 @@ from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.api.model_zoo import load_url |
|
|
|
from fastNLP.api.processor import ModelProcessor |
|
|
|
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader |
|
|
|
from reproduction.pos_tag_model.pos_reader import ConllPOSReader |
|
|
|
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader |
|
|
|
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag |
|
|
|
from fastNLP.core.instance import Instance |
|
|
|
from fastNLP.core.sampler import SequentialSampler |
|
|
|
from fastNLP.core.batch import Batch |
|
|
|
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 |
|
|
|
from fastNLP.api.pipeline import Pipeline |
|
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
|
from fastNLP.api.processor import IndexerProcessor |
|
|
|
|
|
|
|
|
|
|
|
# TODO add pretrain urls |
|
|
@@ -65,7 +67,7 @@ class POS(API): |
|
|
|
:param content: list of list of str. Each string is a token(word). |
|
|
|
:return answer: list of list of str. Each string is a tag. |
|
|
|
""" |
|
|
|
if not hasattr(self, 'pipeline'): |
|
|
|
if not hasattr(self, "pipeline"): |
|
|
|
raise ValueError("You have to load model first.") |
|
|
|
|
|
|
|
sentence_list = [] |
|
|
@@ -104,47 +106,35 @@ class POS(API): |
|
|
|
elif isinstance(content, list): |
|
|
|
return output |
|
|
|
|
|
|
|
def test(self, filepath):
    """Build a tag-indexing pipeline and load the CoNLL data at ``filepath``.

    NOTE: the evaluation code further down is enclosed in a bare string
    literal, so it never executes; as written this method returns ``None``.

    :param filepath: path handed to ``ConllPOSReader().load``.
    """
    tag_proc = self._dict['tag_indexer']

    # The model is read from the third pipeline step; the first two steps are
    # copied into a fresh Pipeline with the tag indexer appended.
    model = self.pipeline.pipeline[2].model
    pp = Pipeline(self.pipeline.pipeline[:2] + [tag_proc])

    te_dataset = ConllPOSReader().load(filepath)

    """
    evaluator = SeqLabelEvaluator2('word_seq_origin_len')
    end_tagidx_set = set()
    tag_proc.vocab.build_vocab()
    for key, value in tag_proc.vocab.word2idx.items():
        if key.startswith('E-'):
            end_tagidx_set.add(value)
        if key.startswith('S-'):
            end_tagidx_set.add(value)
    evaluator.end_tagidx_set = end_tagidx_set

    pp(te_dataset)
    te_dataset.set_target(truth=True)

    default_valid_args = {"batch_size": 64,
                          "use_cuda": True, "evaluator": evaluator,
                          "model": model, "data": te_dataset}

    tester = Tester(**default_valid_args)

    test_result = tester.test()

    f1 = round(test_result['F'] * 100, 2)
    pre = round(test_result['P'] * 100, 2)
    rec = round(test_result['R'] * 100, 2)
    # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

    return f1, pre, rec
    """
|
|
|
def test(self, file_path):
    """Evaluate the loaded pipeline on a CoNLL file and report span scores.

    The gold tags are indexed into a ``truth`` field, the pipeline is run to
    produce the ``predict`` field, both are right-padded with 0 to a common
    length, and ``SpanFPreRecMetric`` is applied to the padded tensors.

    :param file_path: path handed to ``ZhConllPOSReader().load``.
    :return: dict with keys ``"F1"``, ``"precision"`` and ``"recall"``; each
        value is the metric multiplied by 100 and rounded to two decimals.
    """
    test_data = ZhConllPOSReader().load(file_path)

    tag_vocab = self._dict["tag_vocab"]
    pipeline = self._dict["pipeline"]
    index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag",
                                 new_added_field_name="truth", is_input=False)

    # The pipeline object lives in self._dict and outlives this call. Install
    # the gold-tag indexer only for this run and always put the original step
    # list back; otherwise every call to test() would stack another indexer
    # and leave the stored pipeline requiring a "tag" field.
    original_steps = pipeline.pipeline
    pipeline.pipeline = [index_tag] + original_steps
    try:
        pipeline(test_data)
    finally:
        pipeline.pipeline = original_steps

    test_data.set_target("truth")
    prediction = test_data.field_arrays["predict"].content
    truth = test_data.field_arrays["truth"].content
    seq_len = test_data.field_arrays["word_seq_origin_len"].content

    # Pad by hand: torch.Tensor needs rectangular input, so extend every
    # sequence with 0 up to the longest predicted sequence.
    max_length = max(len(seq) for seq in prediction)
    for idx in range(len(prediction)):
        prediction[idx] = list(prediction[idx]) + [0] * (max_length - len(prediction[idx]))
        truth[idx] = list(truth[idx]) + [0] * (max_length - len(truth[idx]))

    evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
                                  seq_lens="word_seq_origin_len")
    evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)},
              {"truth": torch.Tensor(truth)})
    test_result = evaluator.get_metric()

    f1 = round(test_result['f'] * 100, 2)
    pre = round(test_result['pre'] * 100, 2)
    rec = round(test_result['rec'] * 100, 2)

    return {"F1": f1, "precision": pre, "recall": rec}
|
|
|
|
|
|
|
|
|
|
|
class CWS(API): |
|
|
@@ -316,8 +306,8 @@ if __name__ == "__main__": |
|
|
|
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', |
|
|
|
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', |
|
|
|
'那么这款无人机到底有多厉害?'] |
|
|
|
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) |
|
|
|
print(pos.predict(s)) |
|
|
|
print(pos.test("/home/zyfeng/data/sample.conllx")) |
|
|
|
# print(pos.predict(s)) |
|
|
|
|
|
|
|
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' |
|
|
|
# cws = CWS(device='cpu') |
|
|
|