
Merge pull request #221 from Xiaoxiong-Liu/dev0.5.0

Dev0.5.0
tags/v0.4.10
Yige Xu (GitHub), 5 years ago
commit e6f47819e6
18 changed files with 299 additions and 108 deletions
  1. +4 -1     fastNLP/io/loader/__init__.py
  2. +46 -0    fastNLP/io/loader/coreference.py
  3. +3 -0     fastNLP/io/pipe/__init__.py
  4. +170 -0   fastNLP/io/pipe/coreference.py
  5. +1 -1     reproduction/coreference_resolution/README.md
  6. +0 -0     reproduction/coreference_resolution/data_load/__init__.py
  7. +0 -68    reproduction/coreference_resolution/data_load/cr_loader.py
  8. +10 -1    reproduction/coreference_resolution/model/model_re.py
  9. +4 -4     reproduction/coreference_resolution/model/softmax_loss.py
  10. +0 -0    reproduction/coreference_resolution/test/__init__.py
  11. +0 -14   reproduction/coreference_resolution/test/test_dataloader.py
  12. +13 -14  reproduction/coreference_resolution/train.py
  13. +5 -5    reproduction/coreference_resolution/valid.py
  14. +1 -0    test/data_for_tests/coreference/coreference_dev.json
  15. +1 -0    test/data_for_tests/coreference/coreference_test.json
  16. +1 -0    test/data_for_tests/coreference/coreference_train.json
  17. +16 -0   test/io/loader/test_coreference_loader.py
  18. +24 -0   test/io/pipe/test_coreference.py

+4 -1  fastNLP/io/loader/__init__.py

@@ -72,7 +72,9 @@ __all__ = [
"QuoraLoader",
"SNLILoader",
"QNLILoader",
"RTELoader"
"RTELoader",

"CRLoader"
]
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -82,3 +84,4 @@ from .json import JsonLoader
from .loader import Loader
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
+ from .coreference import CRLoader
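With this export in place, the loader is reachable from the loader package itself; a minimal sanity check (a sketch, assuming fastNLP with this change is installed):

    # Quick check of the new package-level export.
    from fastNLP.io.loader import CRLoader

    loader = CRLoader()  # uses the default field mapping added in this PR (doc_key/speakers/clusters/sentences)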

+46 -0  fastNLP/io/loader/coreference.py

@@ -0,0 +1,46 @@
"""undocumented"""

from ...core.dataset import DataSet
from ..file_reader import _read_json
from ...core.instance import Instance
from ...core.const import Const
from .json import JsonLoader


class CRLoader(JsonLoader):
"""
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。

Example::

{"doc_key":"bc/cctv/00/cctv_001",
"speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]",
"clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]",
"sentences":[["I","have","an","apple"],["It","is","good"]]
}

读取预处理好的Conll2012数据。

"""
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
# TODO check 1
self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
"sentences": Const.RAW_WORDS(3)}

def _load(self, path):
"""
Load the data.
:param path: path to the data file, in JSON format

:return: DataSet
"""
dataset = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
dataset.append(Instance(**ins))
return dataset
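A minimal usage sketch for the new loader, mirroring the test added later in this PR (test/io/loader/test_coreference_loader.py); it assumes the snippet is run from the repository root so the JSON fixtures are reachable:

    # Load the small coreference fixtures shipped with this PR into a DataBundle.
    from fastNLP.io.loader.coreference import CRLoader

    root = "test/data_for_tests/coreference/"
    paths = {"train": root + "coreference_train.json",
             "dev": root + "coreference_dev.json",
             "test": root + "coreference_test.json"}

    bundle = CRLoader().load(paths)  # one DataSet per split, fields renamed via Const.RAW_WORDS(0..3)
    print(bundle)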

+3 -0  fastNLP/io/pipe/__init__.py

@@ -38,6 +38,8 @@ __all__ = [
"QuoraPipe",
"QNLIPipe",
"MNLIPipe",

"CoreferencePipe"
]

from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
@@ -47,3 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
from .pipe import Pipe
from .conll import Conll2003Pipe
from .cws import CWSPipe
from .coreference import CoreferencePipe

+170 -0  fastNLP/io/pipe/coreference.py

@@ -0,0 +1,170 @@
"""undocumented"""

__all__ = [
"CoreferencePipe"

]

from .pipe import Pipe
from ..data_bundle import DataBundle
from ..loader.coreference import CRLoader
from ...core.const import Const
from fastNLP.core.vocabulary import Vocabulary
import numpy as np
import collections


class CoreferencePipe(Pipe):
"""
Processes the coreference resolution data to obtain the document genre, speakers, character-level information and sequence lengths.
"""

def __init__(self,config):
super().__init__()
self.config = config

def process(self, data_bundle: DataBundle):
"""
Further process the data produced by the loader.
The raw data contains: raw_key, raw_speaker, raw_words, raw_clusters.

.. csv-table::
    :header: "raw_key", "raw_speaker", "raw_words", "raw_clusters"

    "bc/cctv/00/cctv_0000_0", "[[Speaker#1, Speaker#1],[]]", "[['I','am'],[]]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
    "bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'Speaker#1'],[]]", "[['He','is'],[]]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
    "[...]", "[...]", "[...]", "[...]"

After processing, the data contains the document genre, speaker information, the sentences, their word indices,
characters, sequence lengths, and the target clusters:

.. csv-table::
    :header: "words1", "words2", "words3", "words4", "chars", "seq_len", "target"

    "bc", "[[0,0],[1,1]]", "[['I','am'],[]]", "[[1,2],[]]", "[[[1],[2,3]],[]]", "[2,3]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
    "[...]", "[...]", "[...]", "[...]", "[...]", "[...]", "[...]"


:param data_bundle:
:return:
"""
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3))
vocab.build_vocab()
word2id = vocab.word2idx
data_bundle.set_vocab(vocab,Const.INPUT)
if self.config.char_path:
char_dict = get_char_dict(self.config.char_path)
else:
char_set = set()
for i,w in enumerate(word2id):
if i < 2:
continue
for c in w:
char_set.add(c)

char_dict = collections.defaultdict(int)
char_dict.update({c: i for i, c in enumerate(char_set)})

for name, ds in data_bundle.datasets.items():
# genre
ds.apply(lambda x: genres[x[Const.RAW_WORDS(0)][:2]], new_field_name=Const.INPUTS(0))

# speaker_ids_np
ds.apply(lambda x: speaker2numpy(x[Const.RAW_WORDS(1)], self.config.max_sentences, is_train=name == 'train'),
new_field_name=Const.INPUTS(1))

# sentences
ds.rename_field(Const.RAW_WORDS(3),Const.INPUTS(2))

# doc_np
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[0],
new_field_name=Const.INPUTS(3))
# char_index
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[1],
new_field_name=Const.CHAR_INPUT)
# seq len
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[2],
new_field_name=Const.INPUT_LEN)

# clusters
ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)


ds.set_ignore_type(Const.TARGET)
ds.set_padder(Const.TARGET, None)
ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN)
ds.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths):
bundle = CRLoader().load(paths)
return self.process(bundle)


# helper

def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train):
docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train)
assert max(length) == max_len
assert char_index.shape[0] == len(length)
assert char_index.shape[1] == max_len
doc_np = np.zeros((len(docvec), max_len), int)
for i in range(len(docvec)):
for j in range(len(docvec[i])):
doc_np[i][j] = docvec[i][j]
return doc_np, char_index, length

def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train):
max_len = 0
max_word_length = 0
docvex = []
length = []
if is_train:
sent_num = min(max_sentences,len(doc))
else:
sent_num = len(doc)

for i in range(sent_num):
sent = doc[i]
length.append(len(sent))
if (len(sent) > max_len):
max_len = len(sent)
sent_vec =[]
for j,word in enumerate(sent):
if len(word)>max_word_length:
max_word_length = len(word)
if word in word2id:
sent_vec.append(word2id[word])
else:
sent_vec.append(word2id["UNK"])
docvex.append(sent_vec)

char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int)
for i in range(sent_num):
sent = doc[i]
for j,word in enumerate(sent):
char_index[i, j, :len(word)] = [char_dict[c] for c in word]

return docvex,char_index,length,max_len

def speaker2numpy(speakers_raw,max_sentences,is_train):
if is_train and len(speakers_raw)> max_sentences:
speakers_raw = speakers_raw[0:max_sentences]
speakers = flatten(speakers_raw)
speaker_dict = {s: i for i, s in enumerate(set(speakers))}
speaker_ids = np.array([speaker_dict[s] for s in speakers])
return speaker_ids

# flatten a list of lists
def flatten(l):
return [item for sublist in l for item in sublist]

def get_char_dict(path):
vocab = ["<UNK>"]
with open(path) as f:
vocab.extend(c.strip() for c in f.readlines())
char_dict = collections.defaultdict(int)
char_dict.update({c: i for i, c in enumerate(vocab)})
return char_dict
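A minimal end-to-end sketch of the new pipe, mirroring test/io/pipe/test_coreference.py added below; the stub Config carries only the three attributes the pipe actually reads (max_sentences, filter, char_path), and the paths point at the JSON fixtures from this PR:

    # Run CoreferencePipe on the test fixtures (from the repository root).
    from fastNLP.io.pipe.coreference import CoreferencePipe

    class Config:
        max_sentences = 50   # cap on sentences per document when is_train is True
        filter = [3, 4, 5]   # filter sizes; the pipe passes max(filter) to doc2numpy
        char_path = None     # no external char vocabulary, so one is built from the words

    root = "test/data_for_tests/coreference/"
    paths = {"train": root + "coreference_train.json",
             "dev": root + "coreference_dev.json",
             "test": root + "coreference_test.json"}

    bundle = CoreferencePipe(Config()).process_from_file(paths)
    print(bundle)  # DataSets now carry genre/speaker/word/char/seq_len inputs and the cluster target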

+1 -1  reproduction/coreference_resolution/README.md

@@ -1,4 +1,4 @@
# Coreference Resolution Reproduction
## Introduction
Coreference resolution is the task of finding all expressions in a text that refer to the same real-world entity.
For many of the higher-level NLP tasks that involve natural language understanding,


+0 -0  reproduction/coreference_resolution/data_load/__init__.py


+0 -68  reproduction/coreference_resolution/data_load/cr_loader.py

@@ -1,68 +0,0 @@
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
from fastNLP.io.file_reader import _read_json
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from reproduction.coreference_resolution.model.config import Config
import reproduction.coreference_resolution.model.preprocess as preprocess


class CRLoader(JsonLoader):
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)

def _load(self, path):
"""
Load the data.
:param path:
:return:
"""
dataset = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
dataset.append(Instance(**ins))
return dataset

def process(self, paths, **kwargs):
data_info = DataBundle()
for name in ['train', 'test', 'dev']:
data_info.datasets[name] = self.load(paths[name])

config = Config()
vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
vocab.build_vocab()
word2id = vocab.word2idx

char_dict = preprocess.get_char_dict(config.char_path)
data_info.vocabs = vocab

genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}

for name, ds in data_info.datasets.items():
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[0],
new_field_name='doc_np')
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[1],
new_field_name='char_index')
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[2],
new_field_name='seq_len')
ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'),
new_field_name='speaker_ids_np')
ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')

ds.set_ignore_type('clusters')
ds.set_padder('clusters', None)
ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
ds.set_target("clusters")

# train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
# train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)

return data_info




+10 -1  reproduction/coreference_resolution/model/model_re.py

@@ -8,6 +8,7 @@ from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from reproduction.coreference_resolution.model import preprocess
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.core.const import Const
import random

# set the random seed
@@ -415,7 +416,7 @@ class Model(BaseModel):
return predicted_clusters


- def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
+ def forward(self, words1, words2, words3, words4, chars, seq_len):
"""
All actual inputs are tensors.
:param sentences: the sentences, converted to numpy by fastNLP,
@@ -426,6 +427,14 @@ class Model(BaseModel):
:param seq_len: converted to a Tensor by fastNLP
:return:
"""

sentences = words3
doc_np = words4
speaker_ids_np = words2
genre = words1
char_index = chars


# change for fastNLP
sentences = sentences[0].tolist()
doc_tensor = doc_np[0]


+4 -4  reproduction/coreference_resolution/model/softmax_loss.py

@@ -11,18 +11,18 @@ class SoftmaxLoss(LossBase):
Allows multi-label classification.
"""

- def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None):
+ def __init__(self, antecedent_scores=None, target=None, mention_start_tensor=None, mention_end_tensor=None):
"""

:param pred:
:param target:
"""
super().__init__()
- self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters,
+ self._init_param_map(antecedent_scores=antecedent_scores, target=target,
mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor)

- def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor):
- antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor,
+ def get_loss(self, antecedent_scores, target, mention_start_tensor, mention_end_tensor):
+ antecedent_labels = get_labels(target[0], mention_start_tensor, mention_end_tensor,
Config().max_antecedents)

antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda))


+0 -0  reproduction/coreference_resolution/test/__init__.py


+0 -14  reproduction/coreference_resolution/test/test_dataloader.py

@@ -1,14 +0,0 @@
import unittest
from ..data_load.cr_loader import CRLoader

class Test_CRLoader(unittest.TestCase):
def test_cr_loader(self):
train_path = 'data/train.english.jsonlines.mini'
dev_path = 'data/dev.english.jsonlines.minid'
test_path = 'data/test.english.jsonlines'
cr = CRLoader()
data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path})

print(data_info.datasets['train'][0])
print(data_info.datasets['dev'][0])
print(data_info.datasets['test'][0])

+13 -14  reproduction/coreference_resolution/train.py

@@ -7,7 +7,9 @@ from torch.optim import Adam
from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer

- from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+ from fastNLP.io.pipe.coreference import CoreferencePipe
+ from fastNLP.core.const import Const

from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
@@ -36,18 +38,15 @@ if __name__ == "__main__":

print(config)

- @cache_results('cache.pkl')
+ # @cache_results('cache.pkl')
def cache():
- cr_train_dev_test = CRLoader()
-
- data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
- 'test': config.test_path})
- return data_info
- data_info = cache()
- print("Dataset split:\ntrain:", str(len(data_info.datasets["train"])),
- "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
+ bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
+ return bundle
+ data_bundle = cache()
+ print("Dataset split:\ntrain:", str(len(data_bundle.get_dataset("train"))),
+ "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
# print(data_info)
- model = Model(data_info.vocabs, config)
+ model = Model(data_bundle.get_vocab(Const.INPUT), config)
print(model)

loss = SoftmaxLoss()
@@ -58,11 +57,11 @@ if __name__ == "__main__":

lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)

- trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
- loss=loss, metrics=metric, check_code_level=-1,sampler=None,
+ trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"],
+ loss=loss, metrics=metric, check_code_level=-1, sampler=None,
batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
optimizer=optim,
- save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
+ save_path=None,
callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
print()
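In short, the refactored train.py replaces the deleted CRLoader.process call with the new pipe and reads everything back from the resulting DataBundle; a sketch of the new data flow (assuming the paths in Config point at real data):

    from fastNLP.core.const import Const
    from fastNLP.io.pipe.coreference import CoreferencePipe
    from reproduction.coreference_resolution.model.config import Config

    config = Config()
    bundle = CoreferencePipe(config).process_from_file(
        {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})

    vocab = bundle.get_vocab(Const.INPUT)    # word vocabulary stored by the pipe, passed to Model
    train_ds = bundle.get_dataset("train")   # replaces the old data_info.datasets["train"]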



+5 -5  reproduction/coreference_resolution/valid.py

@@ -1,7 +1,8 @@
import torch
from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.metric import CRMetric
- from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+ from fastNLP.io.pipe.coreference import CoreferencePipe

from fastNLP import Tester
import argparse

@@ -11,13 +12,12 @@ if __name__=='__main__':
parser.add_argument('--path')
args = parser.parse_args()
- cr_loader = CRLoader()
config = Config()
- data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
- 'test': config.test_path})
+ bundle = CoreferencePipe(Config()).process_from_file(
+ {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
metirc = CRMetric()
model = torch.load(args.path)
- tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
+ tester = Tester(bundle.get_dataset("test"),model,metirc,batch_size=1,device="cuda:0")
tester.test()
print('test over')



+1 -0  test/data_for_tests/coreference/coreference_dev.json

@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0000_0", "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]]}

+1 -0  test/data_for_tests/coreference/coreference_test.json

@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0005_0", "speakers": [["speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1"], ["speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1", "speaker#1"]], "clusters": [[[57, 59], [25, 27], [42, 44]]], "sentences": [["--", "basically", ",", "it", "was", "unanimously", "agreed", "upon", "by", "the", "various", "relevant", "parties", "."], ["To", "express", "its", "determination", ",", "the", "Chinese", "securities", "regulatory", "department", "compares", "this", "stock", "reform", "to", "a", "die", "that", "has", "been", "cast", "."]]}

+1 -0  test/data_for_tests/coreference/coreference_train.json

@@ -0,0 +1 @@
{"doc_key": "bc/cctv/00/cctv_0001_0", "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], "clusters": [[[113, 114], [42, 45], [88, 91]]], "sentences": [["What", "kind", "of", "memory", "?"], ["We", "respectfully", "invite", "you", "to", "watch", "a", "special", "edition", "of", "Across", "China", "."]]}

+16 -0  test/io/loader/test_coreference_loader.py

@@ -0,0 +1,16 @@
from fastNLP.io.loader.coreference import CRLoader
import unittest

class TestCR(unittest.TestCase):
def test_load(self):

test_root = "test/data_for_tests/coreference/"
train_path = test_root+"coreference_train.json"
dev_path = test_root+"coreference_dev.json"
test_path = test_root+"coreference_test.json"
paths = {"train": train_path,"dev":dev_path,"test":test_path}

bundle1 = CRLoader().load(paths)
bundle2 = CRLoader().load(test_root)
print(bundle1)
print(bundle2)

+24 -0  test/io/pipe/test_coreference.py

@@ -0,0 +1,24 @@
import unittest
from fastNLP.io.pipe.coreference import CoreferencePipe


class TestCR(unittest.TestCase):

def test_load(self):
class Config():
max_sentences = 50
filter = [3, 4, 5]
char_path = None
config = Config()

file_root_path = "test/data_for_tests/coreference/"
train_path = file_root_path + "coreference_train.json"
dev_path = file_root_path + "coreference_dev.json"
test_path = file_root_path + "coreference_test.json"

paths = {"train": train_path, "dev": dev_path, "test": test_path}

bundle1 = CoreferencePipe(config).process_from_file(paths)
bundle2 = CoreferencePipe(config).process_from_file(file_root_path)
print(bundle1)
print(bundle2)
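Both new test modules can be run from the repository root with the standard unittest runner; a minimal sketch (assuming the test sub-packages are importable; otherwise pointing pytest at the two files works the same way):

    # Discover and run the coreference tests added in this PR.
    import unittest

    suite = unittest.defaultTestLoader.discover("test/io", pattern="test_coreference*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)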
