|
@@ -1,6 +1,3 @@ |
|
|
""" |
|
|
|
|
|
api.api的介绍文档 |
|
|
|
|
|
""" |
|
|
|
|
|
import warnings |
|
|
import warnings |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
@@ -8,15 +5,14 @@ import torch |
|
|
warnings.filterwarnings('ignore') |
|
|
warnings.filterwarnings('ignore') |
|
|
import os |
|
|
import os |
|
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.api.utils import load_url |
|
|
|
|
|
from fastNLP.api.processor import ModelProcessor |
|
|
|
|
|
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader |
|
|
|
|
|
from fastNLP.core.instance import Instance |
|
|
|
|
|
from fastNLP.api.pipeline import Pipeline |
|
|
|
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
|
|
|
from fastNLP.api.processor import IndexerProcessor |
|
|
|
|
|
|
|
|
from ..core.dataset import DataSet |
|
|
|
|
|
from .utils import load_url |
|
|
|
|
|
from .processor import ModelProcessor |
|
|
|
|
|
from ..io.dataset_loader import _cut_long_sentence, ConllLoader |
|
|
|
|
|
from ..core.instance import Instance |
|
|
|
|
|
from ..api.pipeline import Pipeline |
|
|
|
|
|
from ..core.metrics import SpanFPreRecMetric |
|
|
|
|
|
from .processor import IndexerProcessor |
|
|
|
|
|
|
|
|
# TODO add pretrain urls |
|
|
# TODO add pretrain urls |
|
|
model_urls = { |
|
|
model_urls = { |
|
@@ -28,9 +24,10 @@ model_urls = { |
|
|
|
|
|
|
|
|
class ConllCWSReader(object): |
|
|
class ConllCWSReader(object): |
|
|
"""Deprecated. Use ConllLoader for all types of conll-format files.""" |
|
|
"""Deprecated. Use ConllLoader for all types of conll-format files.""" |
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load(self, path, cut_long_sent=False): |
|
|
def load(self, path, cut_long_sent=False): |
|
|
""" |
|
|
""" |
|
|
返回的DataSet只包含raw_sentence这个field,内容为str。 |
|
|
返回的DataSet只包含raw_sentence这个field,内容为str。 |
|
@@ -63,7 +60,7 @@ class ConllCWSReader(object): |
|
|
sample.append(line.strip().split()) |
|
|
sample.append(line.strip().split()) |
|
|
if len(sample) > 0: |
|
|
if len(sample) > 0: |
|
|
datalist.append(sample) |
|
|
datalist.append(sample) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ds = DataSet() |
|
|
ds = DataSet() |
|
|
for sample in datalist: |
|
|
for sample in datalist: |
|
|
# print(sample) |
|
|
# print(sample) |
|
@@ -78,7 +75,7 @@ class ConllCWSReader(object): |
|
|
for raw_sentence in sents: |
|
|
for raw_sentence in sents: |
|
|
ds.append(Instance(raw_sentence=raw_sentence)) |
|
|
ds.append(Instance(raw_sentence=raw_sentence)) |
|
|
return ds |
|
|
return ds |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_char_lst(self, sample): |
|
|
def get_char_lst(self, sample): |
|
|
if len(sample) == 0: |
|
|
if len(sample) == 0: |
|
|
return None |
|
|
return None |
|
@@ -90,11 +87,13 @@ class ConllCWSReader(object): |
|
|
text.append(t1) |
|
|
text.append(t1) |
|
|
return text |
|
|
return text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConllxDataLoader(ConllLoader): |
|
|
class ConllxDataLoader(ConllLoader): |
|
|
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 |
|
|
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 |
|
|
|
|
|
|
|
|
Deprecated. Use ConllLoader for all types of conll-format files. |
|
|
Deprecated. Use ConllLoader for all types of conll-format files. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
headers = [ |
|
|
headers = [ |
|
|
'words', 'pos_tags', 'heads', 'labels', |
|
|
'words', 'pos_tags', 'heads', 'labels', |
|
@@ -106,18 +105,15 @@ class ConllxDataLoader(ConllLoader): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class API: |
|
|
class API: |
|
|
""" |
|
|
|
|
|
这是 API 类的文档 |
|
|
|
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
self.pipeline = None |
|
|
self.pipeline = None |
|
|
self._dict = None |
|
|
self._dict = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, *args, **kwargs): |
|
|
def predict(self, *args, **kwargs): |
|
|
"""Do prediction for the given input. |
|
|
"""Do prediction for the given input. |
|
|
""" |
|
|
""" |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(self, file_path): |
|
|
def test(self, file_path): |
|
|
"""Test performance over the given data set. |
|
|
"""Test performance over the given data set. |
|
|
|
|
|
|
|
@@ -125,7 +121,7 @@ class API: |
|
|
:return: a dictionary of metric values |
|
|
:return: a dictionary of metric values |
|
|
""" |
|
|
""" |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load(self, path, device): |
|
|
def load(self, path, device): |
|
|
if os.path.exists(os.path.expanduser(path)): |
|
|
if os.path.exists(os.path.expanduser(path)): |
|
|
_dict = torch.load(path, map_location='cpu') |
|
|
_dict = torch.load(path, map_location='cpu') |
|
@@ -145,14 +141,14 @@ class POS(API): |
|
|
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. |
|
|
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model_path=None, device='cpu'): |
|
|
def __init__(self, model_path=None, device='cpu'): |
|
|
super(POS, self).__init__() |
|
|
super(POS, self).__init__() |
|
|
if model_path is None: |
|
|
if model_path is None: |
|
|
model_path = model_urls['pos'] |
|
|
model_path = model_urls['pos'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.load(model_path, device) |
|
|
self.load(model_path, device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, content): |
|
|
def predict(self, content): |
|
|
"""predict函数的介绍, |
|
|
"""predict函数的介绍, |
|
|
函数介绍的第二句,这句话不会换行 |
|
|
函数介绍的第二句,这句话不会换行 |
|
@@ -162,48 +158,48 @@ class POS(API): |
|
|
""" |
|
|
""" |
|
|
if not hasattr(self, "pipeline"): |
|
|
if not hasattr(self, "pipeline"): |
|
|
raise ValueError("You have to load model first.") |
|
|
raise ValueError("You have to load model first.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sentence_list = content |
|
|
sentence_list = content |
|
|
# 1. 检查sentence的类型 |
|
|
# 1. 检查sentence的类型 |
|
|
for sentence in sentence_list: |
|
|
for sentence in sentence_list: |
|
|
if not all((type(obj) == str for obj in sentence)): |
|
|
if not all((type(obj) == str for obj in sentence)): |
|
|
raise ValueError("Input must be list of list of string.") |
|
|
raise ValueError("Input must be list of list of string.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 组建dataset |
|
|
# 2. 组建dataset |
|
|
dataset = DataSet() |
|
|
dataset = DataSet() |
|
|
dataset.add_field("words", sentence_list) |
|
|
dataset.add_field("words", sentence_list) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 使用pipeline |
|
|
# 3. 使用pipeline |
|
|
self.pipeline(dataset) |
|
|
self.pipeline(dataset) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def merge_tag(words_list, tags_list): |
|
|
def merge_tag(words_list, tags_list): |
|
|
rtn = [] |
|
|
rtn = [] |
|
|
for words, tags in zip(words_list, tags_list): |
|
|
for words, tags in zip(words_list, tags_list): |
|
|
rtn.append([w + "/" + t for w, t in zip(words, tags)]) |
|
|
rtn.append([w + "/" + t for w, t in zip(words, tags)]) |
|
|
return rtn |
|
|
return rtn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output = dataset.field_arrays["tag"].content |
|
|
output = dataset.field_arrays["tag"].content |
|
|
if isinstance(content, str): |
|
|
if isinstance(content, str): |
|
|
return output[0] |
|
|
return output[0] |
|
|
elif isinstance(content, list): |
|
|
elif isinstance(content, list): |
|
|
return merge_tag(content, output) |
|
|
return merge_tag(content, output) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(self, file_path): |
|
|
def test(self, file_path): |
|
|
test_data = ConllxDataLoader().load(file_path) |
|
|
test_data = ConllxDataLoader().load(file_path) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
save_dict = self._dict |
|
|
save_dict = self._dict |
|
|
tag_vocab = save_dict["tag_vocab"] |
|
|
tag_vocab = save_dict["tag_vocab"] |
|
|
pipeline = save_dict["pipeline"] |
|
|
pipeline = save_dict["pipeline"] |
|
|
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) |
|
|
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) |
|
|
pipeline.pipeline = [index_tag] + pipeline.pipeline |
|
|
pipeline.pipeline = [index_tag] + pipeline.pipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_data.rename_field("pos_tags", "tag") |
|
|
test_data.rename_field("pos_tags", "tag") |
|
|
pipeline(test_data) |
|
|
pipeline(test_data) |
|
|
test_data.set_target("truth") |
|
|
test_data.set_target("truth") |
|
|
prediction = test_data.field_arrays["predict"].content |
|
|
prediction = test_data.field_arrays["predict"].content |
|
|
truth = test_data.field_arrays["truth"].content |
|
|
truth = test_data.field_arrays["truth"].content |
|
|
seq_len = test_data.field_arrays["word_seq_origin_len"].content |
|
|
seq_len = test_data.field_arrays["word_seq_origin_len"].content |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# padding by hand |
|
|
# padding by hand |
|
|
max_length = max([len(seq) for seq in prediction]) |
|
|
max_length = max([len(seq) for seq in prediction]) |
|
|
for idx in range(len(prediction)): |
|
|
for idx in range(len(prediction)): |
|
@@ -217,7 +213,7 @@ class POS(API): |
|
|
f1 = round(test_result['f'] * 100, 2) |
|
|
f1 = round(test_result['f'] * 100, 2) |
|
|
pre = round(test_result['pre'] * 100, 2) |
|
|
pre = round(test_result['pre'] * 100, 2) |
|
|
rec = round(test_result['rec'] * 100, 2) |
|
|
rec = round(test_result['rec'] * 100, 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {"F1": f1, "precision": pre, "recall": rec} |
|
|
return {"F1": f1, "precision": pre, "recall": rec} |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -228,14 +224,15 @@ class CWS(API): |
|
|
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 |
|
|
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 |
|
|
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 |
|
|
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, model_path=None, device='cpu'): |
|
|
def __init__(self, model_path=None, device='cpu'): |
|
|
|
|
|
|
|
|
super(CWS, self).__init__() |
|
|
super(CWS, self).__init__() |
|
|
if model_path is None: |
|
|
if model_path is None: |
|
|
model_path = model_urls['cws'] |
|
|
model_path = model_urls['cws'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.load(model_path, device) |
|
|
self.load(model_path, device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, content): |
|
|
def predict(self, content): |
|
|
""" |
|
|
""" |
|
|
分词接口。 |
|
|
分词接口。 |
|
@@ -246,27 +243,27 @@ class CWS(API): |
|
|
""" |
|
|
""" |
|
|
if not hasattr(self, 'pipeline'): |
|
|
if not hasattr(self, 'pipeline'): |
|
|
raise ValueError("You have to load model first.") |
|
|
raise ValueError("You have to load model first.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sentence_list = [] |
|
|
sentence_list = [] |
|
|
# 1. 检查sentence的类型 |
|
|
# 1. 检查sentence的类型 |
|
|
if isinstance(content, str): |
|
|
if isinstance(content, str): |
|
|
sentence_list.append(content) |
|
|
sentence_list.append(content) |
|
|
elif isinstance(content, list): |
|
|
elif isinstance(content, list): |
|
|
sentence_list = content |
|
|
sentence_list = content |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 组建dataset |
|
|
# 2. 组建dataset |
|
|
dataset = DataSet() |
|
|
dataset = DataSet() |
|
|
dataset.add_field('raw_sentence', sentence_list) |
|
|
dataset.add_field('raw_sentence', sentence_list) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 使用pipeline |
|
|
# 3. 使用pipeline |
|
|
self.pipeline(dataset) |
|
|
self.pipeline(dataset) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output = dataset.get_field('output').content |
|
|
output = dataset.get_field('output').content |
|
|
if isinstance(content, str): |
|
|
if isinstance(content, str): |
|
|
return output[0] |
|
|
return output[0] |
|
|
elif isinstance(content, list): |
|
|
elif isinstance(content, list): |
|
|
return output |
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(self, filepath): |
|
|
def test(self, filepath): |
|
|
""" |
|
|
""" |
|
|
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 |
|
|
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 |
|
@@ -292,28 +289,28 @@ class CWS(API): |
|
|
tag_proc = self._dict['tag_proc'] |
|
|
tag_proc = self._dict['tag_proc'] |
|
|
cws_model = self.pipeline.pipeline[-2].model |
|
|
cws_model = self.pipeline.pipeline[-2].model |
|
|
pipeline = self.pipeline.pipeline[:-2] |
|
|
pipeline = self.pipeline.pipeline[:-2] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline.insert(1, tag_proc) |
|
|
pipeline.insert(1, tag_proc) |
|
|
pp = Pipeline(pipeline) |
|
|
pp = Pipeline(pipeline) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reader = ConllCWSReader() |
|
|
reader = ConllCWSReader() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# te_filename = '/home/hyan/ctb3/test.conllx' |
|
|
# te_filename = '/home/hyan/ctb3/test.conllx' |
|
|
te_dataset = reader.load(filepath) |
|
|
te_dataset = reader.load(filepath) |
|
|
pp(te_dataset) |
|
|
pp(te_dataset) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.core.tester import Tester |
|
|
from fastNLP.core.tester import Tester |
|
|
from fastNLP.core.metrics import BMESF1PreRecMetric |
|
|
from fastNLP.core.metrics import BMESF1PreRecMetric |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64, |
|
|
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64, |
|
|
verbose=0) |
|
|
verbose=0) |
|
|
eval_res = tester.test() |
|
|
eval_res = tester.test() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f1 = eval_res['BMESF1PreRecMetric']['f'] |
|
|
f1 = eval_res['BMESF1PreRecMetric']['f'] |
|
|
pre = eval_res['BMESF1PreRecMetric']['pre'] |
|
|
pre = eval_res['BMESF1PreRecMetric']['pre'] |
|
|
rec = eval_res['BMESF1PreRecMetric']['rec'] |
|
|
rec = eval_res['BMESF1PreRecMetric']['rec'] |
|
|
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) |
|
|
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {"F1": f1, "precision": pre, "recall": rec} |
|
|
return {"F1": f1, "precision": pre, "recall": rec} |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -322,25 +319,25 @@ class Parser(API): |
|
|
super(Parser, self).__init__() |
|
|
super(Parser, self).__init__() |
|
|
if model_path is None: |
|
|
if model_path is None: |
|
|
model_path = model_urls['parser'] |
|
|
model_path = model_urls['parser'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.pos_tagger = POS(device=device) |
|
|
self.pos_tagger = POS(device=device) |
|
|
self.load(model_path, device) |
|
|
self.load(model_path, device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, content): |
|
|
def predict(self, content): |
|
|
if not hasattr(self, 'pipeline'): |
|
|
if not hasattr(self, 'pipeline'): |
|
|
raise ValueError("You have to load model first.") |
|
|
raise ValueError("You have to load model first.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1. 利用POS得到分词和pos tagging结果 |
|
|
# 1. 利用POS得到分词和pos tagging结果 |
|
|
pos_out = self.pos_tagger.predict(content) |
|
|
pos_out = self.pos_tagger.predict(content) |
|
|
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()] |
|
|
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 组建dataset |
|
|
# 2. 组建dataset |
|
|
dataset = DataSet() |
|
|
dataset = DataSet() |
|
|
dataset.add_field('wp', pos_out) |
|
|
dataset.add_field('wp', pos_out) |
|
|
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') |
|
|
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') |
|
|
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') |
|
|
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') |
|
|
dataset.rename_field("words", "raw_words") |
|
|
dataset.rename_field("words", "raw_words") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 使用pipeline |
|
|
# 3. 使用pipeline |
|
|
self.pipeline(dataset) |
|
|
self.pipeline(dataset) |
|
|
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred') |
|
|
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred') |
|
@@ -348,7 +345,7 @@ class Parser(API): |
|
|
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output') |
|
|
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output') |
|
|
# output like: [['2/top', '0/root', '4/nn', '2/dep']] |
|
|
# output like: [['2/top', '0/root', '4/nn', '2/dep']] |
|
|
return dataset.field_arrays['output'].content |
|
|
return dataset.field_arrays['output'].content |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_test_file(self, path): |
|
|
def load_test_file(self, path): |
|
|
def get_one(sample): |
|
|
def get_one(sample): |
|
|
sample = list(map(list, zip(*sample))) |
|
|
sample = list(map(list, zip(*sample))) |
|
@@ -360,7 +357,7 @@ class Parser(API): |
|
|
return None |
|
|
return None |
|
|
# return word_seq, pos_seq, head_seq, head_tag_seq |
|
|
# return word_seq, pos_seq, head_seq, head_tag_seq |
|
|
return sample[1], sample[3], list(map(int, sample[6])), sample[7] |
|
|
return sample[1], sample[3], list(map(int, sample[6])), sample[7] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datalist = [] |
|
|
datalist = [] |
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
sample = [] |
|
|
sample = [] |
|
@@ -374,14 +371,14 @@ class Parser(API): |
|
|
sample.append(line.split('\t')) |
|
|
sample.append(line.split('\t')) |
|
|
if len(sample) > 0: |
|
|
if len(sample) > 0: |
|
|
datalist.append(sample) |
|
|
datalist.append(sample) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data = [get_one(sample) for sample in datalist] |
|
|
data = [get_one(sample) for sample in datalist] |
|
|
data_list = list(filter(lambda x: x is not None, data)) |
|
|
data_list = list(filter(lambda x: x is not None, data)) |
|
|
return data_list |
|
|
return data_list |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(self, filepath): |
|
|
def test(self, filepath): |
|
|
data = self.load_test_file(filepath) |
|
|
data = self.load_test_file(filepath) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert(data): |
|
|
def convert(data): |
|
|
BOS = '<BOS>' |
|
|
BOS = '<BOS>' |
|
|
dataset = DataSet() |
|
|
dataset = DataSet() |
|
@@ -396,7 +393,7 @@ class Parser(API): |
|
|
arc_true=heads, |
|
|
arc_true=heads, |
|
|
tags=head_tags)) |
|
|
tags=head_tags)) |
|
|
return dataset |
|
|
return dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ds = convert(data) |
|
|
ds = convert(data) |
|
|
pp = self.pipeline |
|
|
pp = self.pipeline |
|
|
for p in pp: |
|
|
for p in pp: |
|
@@ -417,23 +414,23 @@ class Parser(API): |
|
|
head_cor += 1 if head_pred[i] == head_gold[i] else 0 |
|
|
head_cor += 1 if head_pred[i] == head_gold[i] else 0 |
|
|
uas = head_cor / total |
|
|
uas = head_cor / total |
|
|
# print('uas:{:.2f}'.format(uas)) |
|
|
# print('uas:{:.2f}'.format(uas)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for p in pp: |
|
|
for p in pp: |
|
|
if p.field_name == 'gold_words': |
|
|
if p.field_name == 'gold_words': |
|
|
p.field_name = 'word_list' |
|
|
p.field_name = 'word_list' |
|
|
elif p.field_name == 'gold_pos': |
|
|
elif p.field_name == 'gold_pos': |
|
|
p.field_name = 'pos_list' |
|
|
p.field_name = 'pos_list' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {"USA": round(uas, 5)} |
|
|
return {"USA": round(uas, 5)} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Analyzer: |
|
|
class Analyzer: |
|
|
def __init__(self, device='cpu'): |
|
|
def __init__(self, device='cpu'): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.cws = CWS(device=device) |
|
|
self.cws = CWS(device=device) |
|
|
self.pos = POS(device=device) |
|
|
self.pos = POS(device=device) |
|
|
self.parser = Parser(device=device) |
|
|
self.parser = Parser(device=device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, content, seg=False, pos=False, parser=False): |
|
|
def predict(self, content, seg=False, pos=False, parser=False): |
|
|
if seg is False and pos is False and parser is False: |
|
|
if seg is False and pos is False and parser is False: |
|
|
seg = True |
|
|
seg = True |
|
@@ -447,9 +444,9 @@ class Analyzer: |
|
|
if parser: |
|
|
if parser: |
|
|
parser_output = self.parser.predict(content) |
|
|
parser_output = self.parser.predict(content) |
|
|
output_dict['parser'] = parser_output |
|
|
output_dict['parser'] = parser_output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return output_dict |
|
|
return output_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(self, filepath): |
|
|
def test(self, filepath): |
|
|
output_dict = {} |
|
|
output_dict = {} |
|
|
if self.cws: |
|
|
if self.cws: |
|
@@ -461,5 +458,5 @@ class Analyzer: |
|
|
if self.parser: |
|
|
if self.parser: |
|
|
parser_output = self.parser.test(filepath) |
|
|
parser_output = self.parser.test(filepath) |
|
|
output_dict['parser'] = parser_output |
|
|
output_dict['parser'] = parser_output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return output_dict |
|
|
return output_dict |