@@ -72,7 +72,9 @@ __all__ = [
    "QuoraLoader",
    "SNLILoader",
    "QNLILoader",
    "RTELoader"
    "RTELoader",
    "CRLoader"
]
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -82,3 +84,4 @@ from .json import JsonLoader
from .loader import Loader
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
from .coreference import CRLoader
@@ -0,0 +1,46 @@
"""undocumented"""

from ...core.dataset import DataSet
from ..file_reader import _read_json
from ...core.instance import Instance
from ...core.const import Const
from .json import JsonLoader


class CRLoader(JsonLoader):
    """
    Reads preprocessed CoNLL-2012 coreference data. Each line of the raw file is a json object in which
    ``doc_key`` carries the document's genre, ``speakers`` lists the speaker of every token in every sentence,
    ``clusters`` groups the mention spans that refer to the same real-world entity, and ``sentences``
    holds the tokenized text.

    Example::

        {"doc_key": "bc/cctv/00/cctv_001",
         "speakers": [["Speaker1", "Speaker1", "Speaker1"], ["Speaker1", "Speaker1", "Speaker1"]],
         "clusters": [[[2, 3], [4, 5]], [[7, 8], [18, 20]]],
         "sentences": [["I", "have", "an", "apple"], ["It", "is", "good"]]
        }
    """

    def __init__(self, fields=None, dropna=False):
        super().__init__(fields, dropna)
        # map the raw json keys onto fastNLP's RAW_WORDS field constants
        self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1),
                       "clusters": Const.RAW_WORDS(2), "sentences": Const.RAW_WORDS(3)}

    def _load(self, path):
        """
        Load the data.

        :param path: path to the json data file
        :return: a DataSet with one Instance per json line
        """
        dataset = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
            if self.fields:
                ins = {self.fields[k]: v for k, v in d.items()}
            else:
                ins = d
            dataset.append(Instance(**ins))
        return dataset
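For orientation, a minimal usage sketch of the new loader. The jsonlines paths are placeholders for illustration only; as the unit test later in this diff shows, `load` accepts either a dict of split paths or a directory containing the three files.

```python
# Sketch only: the paths below are hypothetical; point them at preprocessed
# CoNLL-2012 jsonlines files in the format shown in the CRLoader docstring.
from fastNLP.io.loader.coreference import CRLoader

paths = {
    "train": "data/coreference_train.json",
    "dev": "data/coreference_dev.json",
    "test": "data/coreference_test.json",
}
data_bundle = CRLoader().load(paths)        # one DataSet per split
print(data_bundle)                          # split names and sizes
print(data_bundle.get_dataset("train")[0])  # the four raw fields loaded from the json keys
```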
@@ -38,6 +38,8 @@ __all__ = [
    "QuoraPipe",
    "QNLIPipe",
    "MNLIPipe",
    "CoreferencePipe"
]
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
@@ -47,3 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
from .pipe import Pipe
from .conll import Conll2003Pipe
from .cws import CWSPipe
from .coreference import CoreferencePipe
@@ -0,0 +1,170 @@
"""undocumented"""

__all__ = [
    "CoreferencePipe"
]

import collections

import numpy as np

from .pipe import Pipe
from ..data_bundle import DataBundle
from ..loader.coreference import CRLoader
from ...core.const import Const
from fastNLP.core.vocabulary import Vocabulary


class CoreferencePipe(Pipe):
    """
    Pipe for the coreference resolution task: from the raw fields it derives the document genre,
    speaker ids, character-level indices and sequence lengths needed by the model.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def process(self, data_bundle: DataBundle):
        """
        Further process the data produced by the loader.

        The raw data contains raw_key, raw_speaker, raw_words and raw_clusters:

        .. csv-table::
           :header: "raw_key", "raw_speaker", "raw_words", "raw_clusters"

           "bc/cctv/00/cctv_0000_0", "[['Speaker#1', 'Speaker#1'], []]", "[['I', 'am'], []]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
           "bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'Speaker#1'], []]", "[['He', 'is'], []]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
           "[...]", "[...]", "[...]", "[...]"

        After processing, the data contains the document genre, speaker ids, sentences, the word indices
        of each sentence, character indices, sequence lengths and the target clusters:

        .. csv-table::
           :header: "words1", "words2", "words3", "words4", "chars", "seq_len", "target"

           "bc", "[[0,0],[1,1]]", "[['I','am'],[]]", "[[1,2],[]]", "[[[1],[2,3]],[]]", "[2,3]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
           "[...]", "[...]", "[...]", "[...]", "[...]", "[...]", "[...]"

        :param data_bundle: the DataBundle produced by :class:`CRLoader`
        :return: the processed DataBundle
        """
        genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}

        vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name=Const.RAW_WORDS(3))
        vocab.build_vocab()
        word2id = vocab.word2idx
        data_bundle.set_vocab(vocab, Const.INPUT)

        if self.config.char_path:
            char_dict = get_char_dict(self.config.char_path)
        else:
            # build the character dictionary from the word vocabulary (skipping the two special tokens)
            char_set = set()
            for i, w in enumerate(word2id):
                if i < 2:
                    continue
                for c in w:
                    char_set.add(c)
            char_dict = collections.defaultdict(int)
            char_dict.update({c: i for i, c in enumerate(char_set)})

        for name, ds in data_bundle.datasets.items():
            # genre
            ds.apply(lambda x: genres[x[Const.RAW_WORDS(0)][:2]], new_field_name=Const.INPUTS(0))

            # speaker_ids_np
            ds.apply(lambda x: speaker2numpy(x[Const.RAW_WORDS(1)], self.config.max_sentences, is_train=name == 'train'),
                     new_field_name=Const.INPUTS(1))

            # sentences
            ds.rename_field(Const.RAW_WORDS(3), Const.INPUTS(2))

            # doc_np
            ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
                                         self.config.max_sentences, is_train=name == 'train')[0],
                     new_field_name=Const.INPUTS(3))
            # char_index
            ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
                                         self.config.max_sentences, is_train=name == 'train')[1],
                     new_field_name=Const.CHAR_INPUT)
            # seq_len
            ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
                                         self.config.max_sentences, is_train=name == 'train')[2],
                     new_field_name=Const.INPUT_LEN)

            # clusters
            ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)

            ds.set_ignore_type(Const.TARGET)
            ds.set_padder(Const.TARGET, None)
            ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT,
                         Const.INPUT_LEN)
            ds.set_target(Const.TARGET)

        return data_bundle

    def process_from_file(self, paths):
        bundle = CRLoader().load(paths)
        return self.process(bundle)


# ---------------- helpers ----------------

def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train):
    """Convert a document (a list of tokenized sentences) into word-id, char-id and length arrays."""
    docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train)
    assert max(length) == max_len
    assert char_index.shape[0] == len(length)
    assert char_index.shape[1] == max_len
    doc_np = np.zeros((len(docvec), max_len), int)
    for i in range(len(docvec)):
        for j in range(len(docvec[i])):
            doc_np[i][j] = docvec[i][j]
    return doc_np, char_index, length


def _doc2vec(doc, word2id, char_dict, max_filter, max_sentences, is_train):
    max_len = 0
    max_word_length = 0
    docvec = []
    length = []

    # during training only the first max_sentences sentences of a document are kept
    if is_train:
        sent_num = min(max_sentences, len(doc))
    else:
        sent_num = len(doc)

    for i in range(sent_num):
        sent = doc[i]
        length.append(len(sent))
        if len(sent) > max_len:
            max_len = len(sent)
        sent_vec = []
        for j, word in enumerate(sent):
            if len(word) > max_word_length:
                max_word_length = len(word)
            if word in word2id:
                sent_vec.append(word2id[word])
            else:
                sent_vec.append(word2id["UNK"])
        docvec.append(sent_vec)

    char_index = np.zeros((sent_num, max_len, max_word_length), dtype=int)
    for i in range(sent_num):
        sent = doc[i]
        for j, word in enumerate(sent):
            char_index[i, j, :len(word)] = [char_dict[c] for c in word]

    return docvec, char_index, length, max_len


def speaker2numpy(speakers_raw, max_sentences, is_train):
    """Flatten the per-sentence speaker lists and map each speaker name to an integer id."""
    if is_train and len(speakers_raw) > max_sentences:
        speakers_raw = speakers_raw[0:max_sentences]
    speakers = flatten(speakers_raw)
    speaker_dict = {s: i for i, s in enumerate(set(speakers))}
    speaker_ids = np.array([speaker_dict[s] for s in speakers])
    return speaker_ids


def flatten(l):
    """Flatten a list of lists into a single list."""
    return [item for sublist in l for item in sublist]


def get_char_dict(path):
    """Read a character vocabulary file (one character per line), with index 0 reserved for <UNK>."""
    vocab = ["<UNK>"]
    with open(path) as f:
        vocab.extend(c.strip() for c in f.readlines())
    char_dict = collections.defaultdict(int)
    char_dict.update({c: i for i, c in enumerate(vocab)})
    return char_dict
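A minimal end-to-end sketch of the new pipe, mirroring the unit test added at the end of this diff. The `Config` stand-in and the data paths are illustrative only; any object exposing `char_path`, `max_sentences` and `filter` works.

```python
# Sketch only: Config and the paths are illustrative stand-ins, not library API.
from fastNLP.io.pipe.coreference import CoreferencePipe

class Config:
    max_sentences = 50   # sentences kept per document at training time
    filter = [3, 4, 5]   # character-CNN filter widths, as in the unit test below
    char_path = None     # None -> build the char dict from the word vocabulary

paths = {"train": "data/coreference_train.json",
         "dev": "data/coreference_dev.json",
         "test": "data/coreference_test.json"}

data_bundle = CoreferencePipe(Config()).process_from_file(paths)
# inputs: genre, speaker ids, sentences, word ids, char ids, seq_len; target: clusters
print(data_bundle)
```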
@@ -1,4 +1,4 @@
# Coreference Resolution Reproduction (共指消解)
# Coreference Resolution Reproduction (指代消解)
## Introduction
Coreference resolution is the task of finding all expressions in a text that refer to the same real-world entity.
For many higher-level NLP tasks that involve natural language understanding,
@@ -1,68 +0,0 @@
from fastNLP.io.dataset_loader import JsonLoader, DataSet, Instance
from fastNLP.io.file_reader import _read_json
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from reproduction.coreference_resolution.model.config import Config
import reproduction.coreference_resolution.model.preprocess as preprocess


class CRLoader(JsonLoader):
    def __init__(self, fields=None, dropna=False):
        super().__init__(fields, dropna)

    def _load(self, path):
        """
        Load the data.

        :param path:
        :return:
        """
        dataset = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
            if self.fields:
                ins = {self.fields[k]: v for k, v in d.items()}
            else:
                ins = d
            dataset.append(Instance(**ins))
        return dataset

    def process(self, paths, **kwargs):
        data_info = DataBundle()
        for name in ['train', 'test', 'dev']:
            data_info.datasets[name] = self.load(paths[name])

        config = Config()
        vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
        vocab.build_vocab()
        word2id = vocab.word2idx
        char_dict = preprocess.get_char_dict(config.char_path)
        data_info.vocabs = vocab

        genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}

        for name, ds in data_info.datasets.items():
            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
                                                    config.max_sentences, is_train=name == 'train')[0],
                     new_field_name='doc_np')
            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
                                                    config.max_sentences, is_train=name == 'train')[1],
                     new_field_name='char_index')
            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
                                                    config.max_sentences, is_train=name == 'train')[2],
                     new_field_name='seq_len')
            ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name == 'train'),
                     new_field_name='speaker_ids_np')
            ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')

            ds.set_ignore_type('clusters')
            ds.set_padder('clusters', None)
            ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
            ds.set_target("clusters")

        # train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
        # train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)
        return data_info
@@ -8,6 +8,7 @@ from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from reproduction.coreference_resolution.model import preprocess
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.core.const import Const
import random

# set the random seed
@@ -415,7 +416,7 @@ class Model(BaseModel):
        return predicted_clusters

    def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
    def forward(self, words1, words2, words3, words4, chars, seq_len):
        """
        All actual inputs are tensors.

        :param sentences: the sentences, already converted to numpy by fastNLP,
@@ -426,6 +427,14 @@
        :param seq_len: converted to a Tensor by fastNLP
        :return:
        """
        # map the renamed fastNLP input fields back to the variable names used below
        sentences = words3
        doc_np = words4
        speaker_ids_np = words2
        genre = words1
        char_index = chars

        # change for fastNLP
        sentences = sentences[0].tolist()
        doc_tensor = doc_np[0]
@@ -11,18 +11,18 @@ class SoftmaxLoss(LossBase):
    Allows multi-label classification.
    """

    def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None):
    def __init__(self, antecedent_scores=None, target=None, mention_start_tensor=None, mention_end_tensor=None):
        """
        :param pred:
        :param target:
        """
        super().__init__()

        self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters,
        self._init_param_map(antecedent_scores=antecedent_scores, target=target,
                             mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor)

    def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor):
        antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor,
    def get_loss(self, antecedent_scores, target, mention_start_tensor, mention_end_tensor):
        antecedent_labels = get_labels(target[0], mention_start_tensor, mention_end_tensor,
                                       Config().max_antecedents)

        antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda))
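The `clusters` to `target` renaming matters because fastNLP's `_init_param_map` matches the loss's parameter names against field names in the batch, and `CoreferencePipe` marks the cluster field as `Const.TARGET` (i.e. `target`), so the loss now picks it up without an explicit mapping. A hedged sketch of the two equivalent ways to construct the loss (the explicit form assumes `antecedent_scores` is the key in the model's output dict):

```python
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss

# default: parameters are matched to prediction/target fields by name, as train.py does
loss = SoftmaxLoss()

# equivalent explicit mapping; "antecedent_scores" as the output key is an assumption
loss = SoftmaxLoss(antecedent_scores="antecedent_scores", target="target")
```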
@@ -1,14 +0,0 @@
import unittest
from ..data_load.cr_loader import CRLoader


class Test_CRLoader(unittest.TestCase):
    def test_cr_loader(self):
        train_path = 'data/train.english.jsonlines.mini'
        dev_path = 'data/dev.english.jsonlines.minid'
        test_path = 'data/test.english.jsonlines'
        cr = CRLoader()
        data_info = cr.process({'train': train_path, 'dev': dev_path, 'test': test_path})

        print(data_info.datasets['train'][0])
        print(data_info.datasets['dev'][0])
        print(data_info.datasets['test'][0])
@@ -7,7 +7,9 @@ from torch.optim import Adam
from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer

from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
from fastNLP.io.pipe.coreference import CoreferencePipe
from fastNLP.core.const import Const
from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
@@ -36,18 +38,15 @@ if __name__ == "__main__":
    print(config)

    @cache_results('cache.pkl')
    # @cache_results('cache.pkl')
    def cache():
        cr_train_dev_test = CRLoader()
        data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
                                               'test': config.test_path})
        return data_info
    data_info = cache()
    print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])),
          "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
        bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
        return bundle
    data_bundle = cache()
    print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))),
          "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
    # print(data_info)
    model = Model(data_info.vocabs, config)
    model = Model(data_bundle.get_vocab(Const.INPUT), config)
    print(model)

    loss = SoftmaxLoss()
@@ -58,11 +57,11 @@ if __name__ == "__main__":
    lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)

    trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
                      loss=loss, metrics=metric, check_code_level=-1,sampler=None,
    trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"],
                      loss=loss, metrics=metric, check_code_level=-1, sampler=None,
                      batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
                      optimizer=optim,
                      save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
                      save_path=None,
                      callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
    print()
@@ -1,7 +1,8 @@
import torch
from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.metric import CRMetric
from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
from fastNLP.io.pipe.coreference import CoreferencePipe
from fastNLP import Tester
import argparse
@@ -11,13 +12,12 @@ if __name__=='__main__':
    parser.add_argument('--path')
    args = parser.parse_args()

    cr_loader = CRLoader()
    config = Config()
    data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
                                   'test': config.test_path})
    bundle = CoreferencePipe(Config()).process_from_file(
        {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})

    metirc = CRMetric()
    model = torch.load(args.path)
    tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
    tester = Tester(bundle.get_dataset("test"),model,metirc,batch_size=1,device="cuda:0")
    tester.test()
    print('test over')
@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0000_0", "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]]} |
@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0005_0", "speakers": [["speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1"], ["speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1"]], "clusters": [[[57, 59], [25, 27], [42, 44]]], "sentences": [["--", "basically", ",", "it", "was", "unanimously", "agreed", "upon", "by", "the", "various", "relevant", "parties", "."], ["To", "express", "its", "determination", ",", "the", "Chinese", "securities", "regulatory", "department", "compares", "this", "stock", "reform", "to", "a", "die", "that", "has", "been", "cast", "."]]} |
@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0001_0", "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], "clusters": [[[113, 114], [42, 45], [88, 91]]], "sentences": [["What", "kind", "of", "memory", "?"], ["We", "respectfully", "invite", "you", "to", "watch", "a", "special", "edition", "of", "Across", "China", "."]]} |
@@ -0,0 +1,16 @@
from fastNLP.io.loader.coreference import CRLoader
import unittest


class TestCR(unittest.TestCase):

    def test_load(self):
        test_root = "test/data_for_tests/coreference/"
        train_path = test_root + "coreference_train.json"
        dev_path = test_root + "coreference_dev.json"
        test_path = test_root + "coreference_test.json"
        paths = {"train": train_path, "dev": dev_path, "test": test_path}

        bundle1 = CRLoader().load(paths)
        bundle2 = CRLoader().load(test_root)
        print(bundle1)
        print(bundle2)
@@ -0,0 +1,24 @@
import unittest
from fastNLP.io.pipe.coreference import CoreferencePipe


class TestCR(unittest.TestCase):

    def test_load(self):

        class Config():
            max_sentences = 50
            filter = [3, 4, 5]
            char_path = None
        config = Config()

        file_root_path = "test/data_for_tests/coreference/"
        train_path = file_root_path + "coreference_train.json"
        dev_path = file_root_path + "coreference_dev.json"
        test_path = file_root_path + "coreference_test.json"

        paths = {"train": train_path, "dev": dev_path, "test": test_path}

        bundle1 = CoreferencePipe(config).process_from_file(paths)
        bundle2 = CoreferencePipe(config).process_from_file(file_root_path)
        print(bundle1)
        print(bundle2)