
Merge pull request #221 from Xiaoxiong-Liu/dev0.5.0

Dev0.5.0
Tags: v0.4.10
Yige Xu · 5 years ago
Commit e6f47819e6
18 changed files with 299 additions and 108 deletions
  1. fastNLP/io/loader/__init__.py (+4, -1)
  2. fastNLP/io/loader/coreference.py (+46, -0)
  3. fastNLP/io/pipe/__init__.py (+3, -0)
  4. fastNLP/io/pipe/coreference.py (+170, -0)
  5. reproduction/coreference_resolution/README.md (+1, -1)
  6. reproduction/coreference_resolution/data_load/__init__.py (+0, -0)
  7. reproduction/coreference_resolution/data_load/cr_loader.py (+0, -68)
  8. reproduction/coreference_resolution/model/model_re.py (+10, -1)
  9. reproduction/coreference_resolution/model/softmax_loss.py (+4, -4)
  10. reproduction/coreference_resolution/test/__init__.py (+0, -0)
  11. reproduction/coreference_resolution/test/test_dataloader.py (+0, -14)
  12. reproduction/coreference_resolution/train.py (+13, -14)
  13. reproduction/coreference_resolution/valid.py (+5, -5)
  14. test/data_for_tests/coreference/coreference_dev.json (+1, -0)
  15. test/data_for_tests/coreference/coreference_test.json (+1, -0)
  16. test/data_for_tests/coreference/coreference_train.json (+1, -0)
  17. test/io/loader/test_coreference_loader.py (+16, -0)
  18. test/io/pipe/test_coreference.py (+24, -0)

fastNLP/io/loader/__init__.py (+4, -1)

@@ -72,7 +72,9 @@ __all__ = [
 "QuoraLoader",
 "SNLILoader",
 "QNLILoader",
-"RTELoader"
+"RTELoader",
+
+"CRLoader"
 ]
 from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
 from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -82,3 +84,4 @@ from .json import JsonLoader
 from .loader import Loader
 from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
 from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
+from .coreference import CRLoader

fastNLP/io/loader/coreference.py (+46, -0)

@@ -0,0 +1,46 @@
"""undocumented"""

from ...core.dataset import DataSet
from ..file_reader import _read_json
from ...core.instance import Instance
from ...core.const import Const
from .json import JsonLoader


class CRLoader(JsonLoader):
"""
Each line of the raw data should be one JSON object: doc_key carries the genre of the document, speakers carries the speaker information for each sentence, clusters groups the mention spans that refer to the same real-world entity, and sentences holds the tokenized text.

Example::

{"doc_key": "bc/cctv/00/cctv_001",
"speakers": [["Speaker1", "Speaker1", "Speaker1"], ["Speaker1", "Speaker1", "Speaker1"]],
"clusters": [[[2, 3], [4, 5]], [[7, 8], [18, 20]]],
"sentences": [["I", "have", "an", "apple"], ["It", "is", "good"]]
}

Reads preprocessed CoNLL-2012 data.

"""
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
# TODO check 1
self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
"sentences": Const.RAW_WORDS(3)}

def _load(self, path):
"""
Load the data.
:param path: path to the data file (JSON lines)
:return: a DataSet with one Instance per document
"""
dataset = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
dataset.append(Instance(**ins))
return dataset
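
A minimal usage sketch of the new loader, mirroring the test added at the bottom of this PR (the paths are the small fixtures under test/data_for_tests/coreference/):

from fastNLP.io.loader import CRLoader

# Each file holds one preprocessed CoNLL-2012-style document per line,
# in the JSON format described in the docstring above.
paths = {
    "train": "test/data_for_tests/coreference/coreference_train.json",
    "dev": "test/data_for_tests/coreference/coreference_dev.json",
    "test": "test/data_for_tests/coreference/coreference_test.json",
}
bundle = CRLoader().load(paths)        # a DataBundle with train/dev/test DataSets
print(bundle.get_dataset("train")[0])  # raw doc_key / speakers / clusters / sentences fields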

fastNLP/io/pipe/__init__.py (+3, -0)

@@ -38,6 +38,8 @@ __all__ = [
 "QuoraPipe",
 "QNLIPipe",
 "MNLIPipe",
+
+"CoreferencePipe"
 ]


 from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
@@ -47,3 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
 from .pipe import Pipe
 from .conll import Conll2003Pipe
 from .cws import CWSPipe
+from .coreference import CoreferencePipe

fastNLP/io/pipe/coreference.py (+170, -0)

@@ -0,0 +1,170 @@
"""undocumented"""

__all__ = [
"CoreferencePipe"

]

from .pipe import Pipe
from ..data_bundle import DataBundle
from ..loader.coreference import CRLoader
from ...core.const import Const
from fastNLP.core.vocabulary import Vocabulary
import numpy as np
import collections


class CoreferencePipe(Pipe):
"""
对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。
"""

def __init__(self,config):
super().__init__()
self.config = config

def process(self, data_bundle: DataBundle):
"""
对load进来的数据进一步处理
原始数据包含:raw_key,raw_speaker,raw_words,raw_clusters
.. csv-table::
:header: "raw_key", "raw_speaker","raw_words","raw_clusters"

"bc/cctv/00/cctv_0000_0", "[[Speaker#1, Speaker#1],[]]","[['I','am'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]"
"bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]"
"[...]", "[...]","[...]","[...]"

处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target:
.. csv-table::
:header: "words1", "words2","words3","words4","chars","seq_len","target"

"bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]"
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]"


:param data_bundle:
:return:
"""
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3))
vocab.build_vocab()
word2id = vocab.word2idx
data_bundle.set_vocab(vocab,Const.INPUT)
if self.config.char_path:
char_dict = get_char_dict(self.config.char_path)
else:
char_set = set()
for i,w in enumerate(word2id):
if i < 2:
continue
for c in w:
char_set.add(c)

char_dict = collections.defaultdict(int)
char_dict.update({c: i for i, c in enumerate(char_set)})

for name, ds in data_bundle.datasets.items():
# genre
ds.apply(lambda x: genres[x[Const.RAW_WORDS(0)][:2]], new_field_name=Const.INPUTS(0))

# speaker_ids_np
ds.apply(lambda x: speaker2numpy(x[Const.RAW_WORDS(1)], self.config.max_sentences, is_train=name == 'train'),
new_field_name=Const.INPUTS(1))

# sentences
ds.rename_field(Const.RAW_WORDS(3),Const.INPUTS(2))

# doc_np
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[0],
new_field_name=Const.INPUTS(3))
# char_index
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[1],
new_field_name=Const.CHAR_INPUT)
# seq len
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[2],
new_field_name=Const.INPUT_LEN)

# clusters
ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)


ds.set_ignore_type(Const.TARGET)
ds.set_padder(Const.TARGET, None)
ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN)
ds.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths):
bundle = CRLoader().load(paths)
return self.process(bundle)


# helper

def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train):
docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train)
assert max(length) == max_len
assert char_index.shape[0] == len(length)
assert char_index.shape[1] == max_len
doc_np = np.zeros((len(docvec), max_len), int)
for i in range(len(docvec)):
for j in range(len(docvec[i])):
doc_np[i][j] = docvec[i][j]
return doc_np, char_index, length

def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train):
max_len = 0
max_word_length = 0
docvex = []
length = []
if is_train:
sent_num = min(max_sentences,len(doc))
else:
sent_num = len(doc)

for i in range(sent_num):
sent = doc[i]
length.append(len(sent))
if (len(sent) > max_len):
max_len = len(sent)
sent_vec =[]
for j,word in enumerate(sent):
if len(word)>max_word_length:
max_word_length = len(word)
if word in word2id:
sent_vec.append(word2id[word])
else:
sent_vec.append(word2id["UNK"])
docvex.append(sent_vec)

char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int)
for i in range(sent_num):
sent = doc[i]
for j,word in enumerate(sent):
char_index[i, j, :len(word)] = [char_dict[c] for c in word]

return docvex,char_index,length,max_len

def speaker2numpy(speakers_raw,max_sentences,is_train):
if is_train and len(speakers_raw)> max_sentences:
speakers_raw = speakers_raw[0:max_sentences]
speakers = flatten(speakers_raw)
speaker_dict = {s: i for i, s in enumerate(set(speakers))}
speaker_ids = np.array([speaker_dict[s] for s in speakers])
return speaker_ids

# flatten a nested list of lists
def flatten(l):
return [item for sublist in l for item in sublist]

def get_char_dict(path):
vocab = ["<UNK>"]
with open(path) as f:
vocab.extend(c.strip() for c in f.readlines())
char_dict = collections.defaultdict(int)
char_dict.update({c: i for i, c in enumerate(vocab)})
return char_dict
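
A hedged end-to-end sketch of the pipe, using a throwaway Config that carries only the three attributes the pipe reads (max_sentences, filter, char_path), in the same shape as the test added at the end of this PR:

from fastNLP.io.pipe.coreference import CoreferencePipe
from fastNLP.core.const import Const

class Config:
    # minimal, hypothetical config: only the attributes CoreferencePipe uses here
    max_sentences = 50   # sentences kept per document at training time
    filter = [3, 4, 5]   # CNN filter widths; max(filter) is what gets passed to doc2numpy
    char_path = None     # None -> build the char dict from the word vocabulary

paths = {
    "train": "test/data_for_tests/coreference/coreference_train.json",
    "dev": "test/data_for_tests/coreference/coreference_dev.json",
    "test": "test/data_for_tests/coreference/coreference_test.json",
}
bundle = CoreferencePipe(Config()).process_from_file(paths)

train = bundle.get_dataset("train")
# For the two-sentence example document in the loader docstring, the helpers above
# would produce words4 (doc_np) of shape (2, 4), chars of shape (2, 4, 5) and seq_len [4, 3].
print(train[0])
print(bundle.get_vocab(Const.INPUT))  # the word vocabulary the pipe stored under Const.INPUT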

reproduction/coreference_resolution/README.md (+1, -1)

@@ -1,4 +1,4 @@
 # Coreference Resolution Reproduction
 ## Introduction
 Coreference resolution is the task of finding all expressions in a text that refer to the same real-world entity.
 For many higher-level NLP tasks that involve natural language understanding,


reproduction/coreference_resolution/data_load/__init__.py (+0, -0)


reproduction/coreference_resolution/data_load/cr_loader.py (+0, -68)

@@ -1,68 +0,0 @@
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
from fastNLP.io.file_reader import _read_json
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from reproduction.coreference_resolution.model.config import Config
import reproduction.coreference_resolution.model.preprocess as preprocess


class CRLoader(JsonLoader):
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)

def _load(self, path):
"""
Load the data.
:param path:
:return:
"""
dataset = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
dataset.append(Instance(**ins))
return dataset

def process(self, paths, **kwargs):
data_info = DataBundle()
for name in ['train', 'test', 'dev']:
data_info.datasets[name] = self.load(paths[name])

config = Config()
vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
vocab.build_vocab()
word2id = vocab.word2idx

char_dict = preprocess.get_char_dict(config.char_path)
data_info.vocabs = vocab

genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}

for name, ds in data_info.datasets.items():
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[0],
new_field_name='doc_np')
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[1],
new_field_name='char_index')
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[2],
new_field_name='seq_len')
ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'),
new_field_name='speaker_ids_np')
ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')

ds.set_ignore_type('clusters')
ds.set_padder('clusters', None)
ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
ds.set_target("clusters")

# train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
# train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)

return data_info




reproduction/coreference_resolution/model/model_re.py (+10, -1)

@@ -8,6 +8,7 @@ from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.encoder.variational_rnn import VarLSTM
 from reproduction.coreference_resolution.model import preprocess
 from fastNLP.io.embed_loader import EmbedLoader
+from fastNLP.core.const import Const
 import random


 # set the random seed
@@ -415,7 +416,7 @@ class Model(BaseModel):
 return predicted_clusters


-def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
+def forward(self, words1, words2, words3, words4, chars, seq_len):
 """
 The actual inputs are all tensors.
 :param sentences: the sentences, converted to numpy arrays by fastNLP,
@@ -426,6 +427,14 @@
 :param seq_len: converted to a Tensor by fastNLP
 :return:
 """
+
+sentences = words3
+doc_np = words4
+speaker_ids_np = words2
+genre = words1
+char_index = chars
+
+
 # change for fastNLP
 sentences = sentences[0].tolist()
 doc_tensor = doc_np[0]


reproduction/coreference_resolution/model/softmax_loss.py (+4, -4)

@@ -11,18 +11,18 @@ class SoftmaxLoss(LossBase):
 Allows multi-label classification.
 """


-def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None):
+def __init__(self, antecedent_scores=None, target=None, mention_start_tensor=None, mention_end_tensor=None):
 """

 :param pred:
 :param target:
 """
 super().__init__()
-self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters,
+self._init_param_map(antecedent_scores=antecedent_scores, target=target,
 mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor)


-def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor):
-antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor,
+def get_loss(self, antecedent_scores, target, mention_start_tensor, mention_end_tensor):
+antecedent_labels = get_labels(target[0], mention_start_tensor, mention_end_tensor,
 Config().max_antecedents)

 antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda))
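
For context on the clusters -> target rename: SoftmaxLoss relies on fastNLP's name-based parameter mapping, so its keyword has to match the field name the pipe marks as the DataSet target. A small hedged illustration (Const.TARGET is simply the string "target"):

from fastNLP.core.const import Const
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss

# CoreferencePipe renames the raw clusters field to Const.TARGET and calls set_target on it,
# so with the loss argument named `target` the Trainer can route that field into get_loss()
# without any explicit key mapping.
assert Const.TARGET == "target"
loss = SoftmaxLoss()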


reproduction/coreference_resolution/test/__init__.py (+0, -0)


reproduction/coreference_resolution/test/test_dataloader.py (+0, -14)

@@ -1,14 +0,0 @@
import unittest
from ..data_load.cr_loader import CRLoader

class Test_CRLoader(unittest.TestCase):
def test_cr_loader(self):
train_path = 'data/train.english.jsonlines.mini'
dev_path = 'data/dev.english.jsonlines.minid'
test_path = 'data/test.english.jsonlines'
cr = CRLoader()
data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path})

print(data_info.datasets['train'][0])
print(data_info.datasets['dev'][0])
print(data_info.datasets['test'][0])

reproduction/coreference_resolution/train.py (+13, -14)

@@ -7,7 +7,9 @@ from torch.optim import Adam
 from fastNLP.core.callback import Callback, GradientClipCallback
 from fastNLP.core.trainer import Trainer

-from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from fastNLP.io.pipe.coreference import CoreferencePipe
+from fastNLP.core.const import Const
+
 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.model_re import Model
 from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
@@ -36,18 +38,15 @@ if __name__ == "__main__":

 print(config)

-@cache_results('cache.pkl')
+# @cache_results('cache.pkl')
 def cache():
-cr_train_dev_test = CRLoader()
-
-data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
-'test': config.test_path})
-return data_info
-data_info = cache()
-print("Dataset split:\ntrain:", str(len(data_info.datasets["train"])),
-"\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
+bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
+return bundle
+data_bundle = cache()
+print("Dataset split:\ntrain:", str(len(data_bundle.get_dataset("train"))),
+"\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
 # print(data_info)
-model = Model(data_info.vocabs, config)
+model = Model(data_bundle.get_vocab(Const.INPUT), config)
 print(model)

 loss = SoftmaxLoss()
@@ -58,11 +57,11 @@ if __name__ == "__main__":

 lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)

-trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
-loss=loss, metrics=metric, check_code_level=-1,sampler=None,
+trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"],
+loss=loss, metrics=metric, check_code_level=-1, sampler=None,
 batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
 optimizer=optim,
-save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
+save_path=None,
 callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
 print()
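
If the preprocessing step becomes a bottleneck, the decorator commented out above can be re-enabled so the processed bundle is pickled once and reused; a hedged sketch, assuming config is in scope as in the script and cache_results comes from fastNLP.core.utils:

from fastNLP.core.utils import cache_results
from fastNLP.io.pipe.coreference import CoreferencePipe

# Re-enable the on-disk cache; delete cache.pkl to force the bundle to be rebuilt.
@cache_results('cache.pkl')
def cache():
    return CoreferencePipe(config).process_from_file(
        {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})

data_bundle = cache()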




reproduction/coreference_resolution/valid.py (+5, -5)

@@ -1,7 +1,8 @@
 import torch
 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.metric import CRMetric
-from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from fastNLP.io.pipe.coreference import CoreferencePipe
+
 from fastNLP import Tester
 import argparse


@@ -11,13 +12,12 @@ if __name__=='__main__':
 parser.add_argument('--path')
 args = parser.parse_args()
-cr_loader = CRLoader()
 config = Config()
-data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
-'test': config.test_path})
+bundle = CoreferencePipe(Config()).process_from_file(
+{'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
 metirc = CRMetric()
 model = torch.load(args.path)
-tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
+tester = Tester(bundle.get_dataset("test"), model, metirc, batch_size=1, device="cuda:0")
 tester.test()
 print('test over')




test/data_for_tests/coreference/coreference_dev.json (+1, -0)

@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0000_0", "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]]}

test/data_for_tests/coreference/coreference_test.json (+1, -0)

@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0005_0", "speakers": [["speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1"], ["speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1"]], "clusters": [[[57, 59], [25, 27], [42, 44]]], "sentences": [["--", "basically", ",", "it", "was", "unanimously", "agreed", "upon", "by", "the", "various", "relevant", "parties", "."], ["To", "express", "its", "determination", ",", "the", "Chinese", "securities", "regulatory", "department", "compares", "this", "stock", "reform", "to", "a", "die", "that", "has", "been", "cast", "."]]}

test/data_for_tests/coreference/coreference_train.json (+1, -0)

@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0001_0", "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], "clusters": [[[113, 114], [42, 45], [88, 91]]], "sentences": [["What", "kind", "of", "memory", "?"], ["We", "respectfully", "invite", "you", "to", "watch", "a", "special", "edition", "of", "Across", "China", "."]]}

test/io/loader/test_coreference_loader.py (+16, -0)

@@ -0,0 +1,16 @@
from fastNLP.io.loader.coreference import CRLoader
import unittest

class TestCR(unittest.TestCase):
def test_load(self):

test_root = "test/data_for_tests/coreference/"
train_path = test_root+"coreference_train.json"
dev_path = test_root+"coreference_dev.json"
test_path = test_root+"coreference_test.json"
paths = {"train": train_path,"dev":dev_path,"test":test_path}

bundle1 = CRLoader().load(paths)
bundle2 = CRLoader().load(test_root)
print(bundle1)
print(bundle2)

test/io/pipe/test_coreference.py (+24, -0)

@@ -0,0 +1,24 @@
import unittest
from fastNLP.io.pipe.coreference import CoreferencePipe


class TestCR(unittest.TestCase):

def test_load(self):
class Config():
max_sentences = 50
filter = [3, 4, 5]
char_path = None
config = Config()

file_root_path = "test/data_for_tests/coreference/"
train_path = file_root_path + "coreference_train.json"
dev_path = file_root_path + "coreference_dev.json"
test_path = file_root_path + "coreference_test.json"

paths = {"train": train_path, "dev": dev_path, "test": test_path}

bundle1 = CoreferencePipe(config).process_from_file(paths)
bundle2 = CoreferencePipe(config).process_from_file(file_root_path)
print(bundle1)
print(bundle2)
