Browse Source

* refactor test API for POS tagging

* add default sampler for Batch
* fix bug in metrics.py: slice must be integer
tags/v0.3.0^2
FengZiYjun 6 years ago
parent
commit
e9c93ad077
3 changed files with 40 additions and 48 deletions
  1. +35
    -45
      fastNLP/api/api.py
  2. +3
    -1
      fastNLP/core/batch.py
  3. +2
    -2
      fastNLP/core/metrics.py

+ 35
- 45
fastNLP/api/api.py View File

@@ -10,13 +10,15 @@ from fastNLP.core.dataset import DataSet
from fastNLP.api.model_zoo import load_url
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.pos_tag_model.pos_reader import ConllPOSReader
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.api.processor import IndexerProcessor


# TODO add pretrain urls
@@ -65,7 +67,7 @@ class POS(API):
:param content: list of list of str. Each string is a token(word).
:return answer: list of list of str. Each string is a tag.
"""
if not hasattr(self, 'pipeline'):
if not hasattr(self, "pipeline"):
raise ValueError("You have to load model first.")

sentence_list = []
@@ -104,47 +106,35 @@ class POS(API):
elif isinstance(content, list):
return output

def test(self, filepath):

tag_proc = self._dict['tag_indexer']

model = self.pipeline.pipeline[2].model
pipeline = self.pipeline.pipeline[0:2]
pipeline.append(tag_proc)
pp = Pipeline(pipeline)

reader = ConllPOSReader()
te_dataset = reader.load(filepath)

"""
evaluator = SeqLabelEvaluator2('word_seq_origin_len')
end_tagidx_set = set()
tag_proc.vocab.build_vocab()
for key, value in tag_proc.vocab.word2idx.items():
if key.startswith('E-'):
end_tagidx_set.add(value)
if key.startswith('S-'):
end_tagidx_set.add(value)
evaluator.end_tagidx_set = end_tagidx_set

pp(te_dataset)
te_dataset.set_target(truth=True)

default_valid_args = {"batch_size": 64,
"use_cuda": True, "evaluator": evaluator,
"model": model, "data": te_dataset}

tester = Tester(**default_valid_args)

test_result = tester.test()

f1 = round(test_result['F'] * 100, 2)
pre = round(test_result['P'] * 100, 2)
rec = round(test_result['R'] * 100, 2)
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

return f1, pre, rec
"""
def test(self, file_path):
test_data = ZhConllPOSReader().load(file_path)

tag_vocab = self._dict["tag_vocab"]
pipeline = self._dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline

pipeline(test_data)
test_data.set_target("truth")
prediction = test_data.field_arrays["predict"].content
truth = test_data.field_arrays["truth"].content
seq_len = test_data.field_arrays["word_seq_origin_len"].content

# padding by hand
max_length = max([len(seq) for seq in prediction])
for idx in range(len(prediction)):
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx])))
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx])))
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
seq_lens="word_seq_origin_len")
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)},
{"truth": torch.Tensor(truth)})
test_result = evaluator.get_metric()
f1 = round(test_result['f'] * 100, 2)
pre = round(test_result['pre'] * 100, 2)
rec = round(test_result['rec'] * 100, 2)

return {"F1": f1, "precision": pre, "recall": rec}


class CWS(API):
@@ -316,8 +306,8 @@ if __name__ == "__main__":
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
print(pos.predict(s))
print(pos.test("/home/zyfeng/data/sample.conllx"))
# print(pos.predict(s))

# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
# cws = CWS(device='cpu')


+ 3
- 1
fastNLP/core/batch.py View File

@@ -1,6 +1,8 @@
import numpy as np
import torch

from fastNLP.core.sampler import RandomSampler


class Batch(object):
"""Batch is an iterable object which iterates over mini-batches.
@@ -17,7 +19,7 @@ class Batch(object):

"""

def __init__(self, dataset, batch_size, sampler, as_numpy=False):
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler


+ 2
- 2
fastNLP/core/metrics.py View File

@@ -451,8 +451,8 @@ class SpanFPreRecMetric(MetricBase):

batch_size = pred.size(0)
for i in range(batch_size):
pred_tags = pred[i, :seq_lens[i]].tolist()
gold_tags = target[i, :seq_lens[i]].tolist()
pred_tags = pred[i, :int(seq_lens[i])].tolist()
gold_tags = target[i, :int(seq_lens[i])].tolist()

pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]


Loading…
Cancel
Save