@@ -74,7 +74,7 @@ __all__ = [ | |||||
"QNLILoader", | "QNLILoader", | ||||
"RTELoader", | "RTELoader", | ||||
"CRLoader" | |||||
"CoReferenceLoader" | |||||
] | ] | ||||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader | from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader | ||||
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | ||||
@@ -84,4 +84,4 @@ from .json import JsonLoader | |||||
from .loader import Loader | from .loader import Loader | ||||
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | ||||
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | ||||
from .coreference import CRLoader | |||||
from .coreference import CoReferenceLoader |
@@ -1,5 +1,9 @@ | |||||
"""undocumented""" | """undocumented""" | ||||
__all__ = [ | |||||
"CoReferenceLoader", | |||||
] | |||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ..file_reader import _read_json | from ..file_reader import _read_json | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
@@ -7,7 +11,7 @@ from ...core.const import Const | |||||
from .json import JsonLoader | from .json import JsonLoader | ||||
class CRLoader(JsonLoader): | |||||
class CoReferenceLoader(JsonLoader): | |||||
""" | """ | ||||
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | ||||
@@ -24,8 +28,8 @@ class CRLoader(JsonLoader): | |||||
""" | """ | ||||
def __init__(self, fields=None, dropna=False): | def __init__(self, fields=None, dropna=False): | ||||
super().__init__(fields, dropna) | super().__init__(fields, dropna) | ||||
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)} | |||||
# TODO check 1 | |||||
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1), | |||||
# "clusters":Const.TARGET,"sentences":Const.INPUTS(2)} | |||||
self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), | self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), | ||||
"sentences": Const.RAW_WORDS(3)} | "sentences": Const.RAW_WORDS(3)} | ||||
@@ -43,4 +47,4 @@ class CRLoader(JsonLoader): | |||||
else: | else: | ||||
ins = d | ins = d | ||||
dataset.append(Instance(**ins)) | dataset.append(Instance(**ins)) | ||||
return dataset | |||||
return dataset |
@@ -39,7 +39,7 @@ __all__ = [ | |||||
"QNLIPipe", | "QNLIPipe", | ||||
"MNLIPipe", | "MNLIPipe", | ||||
"CoreferencePipe" | |||||
"CoReferencePipe" | |||||
] | ] | ||||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe | from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe | ||||
@@ -49,4 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe | |||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .conll import Conll2003Pipe | from .conll import Conll2003Pipe | ||||
from .cws import CWSPipe | from .cws import CWSPipe | ||||
from .coreference import CoreferencePipe | |||||
from .coreference import CoReferencePipe |
@@ -1,8 +1,7 @@ | |||||
"""undocumented""" | """undocumented""" | ||||
__all__ = [ | __all__ = [ | ||||
"CoreferencePipe" | |||||
"CoReferencePipe" | |||||
] | ] | ||||
import collections | import collections | ||||
@@ -12,11 +11,11 @@ import numpy as np | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from ..data_bundle import DataBundle | from ..data_bundle import DataBundle | ||||
from ..loader.coreference import CRLoader | |||||
from ..loader.coreference import CoReferenceLoader | |||||
from ...core.const import Const | from ...core.const import Const | ||||
class CoreferencePipe(Pipe): | |||||
class CoReferencePipe(Pipe): | |||||
""" | """ | ||||
对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | 对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | ||||
""" | """ | ||||
@@ -52,7 +51,7 @@ class CoreferencePipe(Pipe): | |||||
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3)) | vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3)) | ||||
vocab.build_vocab() | vocab.build_vocab() | ||||
word2id = vocab.word2idx | word2id = vocab.word2idx | ||||
data_bundle.set_vocab(vocab,Const.INPUT) | |||||
data_bundle.set_vocab(vocab, Const.INPUTS(0)) | |||||
if self.config.char_path: | if self.config.char_path: | ||||
char_dict = get_char_dict(self.config.char_path) | char_dict = get_char_dict(self.config.char_path) | ||||
else: | else: | ||||
@@ -93,7 +92,6 @@ class CoreferencePipe(Pipe): | |||||
# clusters | # clusters | ||||
ds.rename_field(Const.RAW_WORDS(2), Const.TARGET) | ds.rename_field(Const.RAW_WORDS(2), Const.TARGET) | ||||
ds.set_ignore_type(Const.TARGET) | ds.set_ignore_type(Const.TARGET) | ||||
ds.set_padder(Const.TARGET, None) | ds.set_padder(Const.TARGET, None) | ||||
ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN) | ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN) | ||||
@@ -102,7 +100,7 @@ class CoreferencePipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths): | def process_from_file(self, paths): | ||||
bundle = CRLoader().load(paths) | |||||
bundle = CoReferenceLoader().load(paths) | |||||
return self.process(bundle) | return self.process(bundle) | ||||
@@ -1,5 +1,3 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
import torch | import torch | ||||
from torch.optim import Adam | from torch.optim import Adam | ||||
@@ -7,20 +5,15 @@ from torch.optim import Adam | |||||
from fastNLP.core.callback import Callback, GradientClipCallback | from fastNLP.core.callback import Callback, GradientClipCallback | ||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.io.pipe.coreference import CoreferencePipe | |||||
from fastNLP.io.pipe.coreference import CoReferencePipe | |||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from reproduction.coreference_resolution.model.config import Config | from reproduction.coreference_resolution.model.config import Config | ||||
from reproduction.coreference_resolution.model.model_re import Model | from reproduction.coreference_resolution.model.model_re import Model | ||||
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | ||||
from reproduction.coreference_resolution.model.metric import CRMetric | from reproduction.coreference_resolution.model.metric import CRMetric | ||||
from fastNLP import SequentialSampler | |||||
from fastNLP import cache_results | |||||
# torch.backends.cudnn.benchmark = False | |||||
# torch.backends.cudnn.deterministic = True | |||||
class LRCallback(Callback): | class LRCallback(Callback): | ||||
def __init__(self, parameters, decay_rate=1e-3): | def __init__(self, parameters, decay_rate=1e-3): | ||||
super().__init__() | super().__init__() | ||||
@@ -38,15 +31,13 @@ if __name__ == "__main__": | |||||
print(config) | print(config) | ||||
# @cache_results('cache.pkl') | |||||
def cache(): | def cache(): | ||||
bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path}) | |||||
bundle = CoReferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, | |||||
'test': config.test_path}) | |||||
return bundle | return bundle | ||||
data_bundle = cache() | data_bundle = cache() | ||||
print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), | |||||
"\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) | |||||
# print(data_info) | |||||
model = Model(data_bundle.get_vocab(Const.INPUT), config) | |||||
print(data_bundle) | |||||
model = Model(data_bundle.get_vocab(Const.INPUTS(0)), config) | |||||
print(model) | print(model) | ||||
loss = SoftmaxLoss() | loss = SoftmaxLoss() | ||||
@@ -59,9 +50,10 @@ if __name__ == "__main__": | |||||
trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"], | trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"], | ||||
loss=loss, metrics=metric, check_code_level=-1, sampler=None, | loss=loss, metrics=metric, check_code_level=-1, sampler=None, | ||||
batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, | |||||
batch_size=1, device=torch.device("cuda:" + config.cuda) if torch.cuda.is_available() else None, | |||||
metric_key='f', n_epochs=config.epoch, | |||||
optimizer=optim, | optimizer=optim, | ||||
save_path= None, | |||||
save_path=None, | |||||
callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) | callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) | ||||
print() | print() | ||||
@@ -1,7 +1,7 @@ | |||||
import torch | import torch | ||||
from reproduction.coreference_resolution.model.config import Config | from reproduction.coreference_resolution.model.config import Config | ||||
from reproduction.coreference_resolution.model.metric import CRMetric | from reproduction.coreference_resolution.model.metric import CRMetric | ||||
from fastNLP.io.pipe.coreference import CoreferencePipe | |||||
from fastNLP.io.pipe.coreference import CoReferencePipe | |||||
from fastNLP import Tester | from fastNLP import Tester | ||||
import argparse | import argparse | ||||
@@ -13,7 +13,7 @@ if __name__=='__main__': | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
config = Config() | config = Config() | ||||
bundle = CoreferencePipe(Config()).process_from_file( | |||||
bundle = CoReferencePipe(Config()).process_from_file( | |||||
{'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path}) | {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path}) | ||||
metirc = CRMetric() | metirc = CRMetric() | ||||
model = torch.load(args.path) | model = torch.load(args.path) | ||||
@@ -1,16 +1,26 @@ | |||||
from fastNLP.io.loader.coreference import CRLoader | |||||
from fastNLP.io.loader.coreference import CoReferenceLoader | |||||
import unittest | import unittest | ||||
class TestCR(unittest.TestCase): | class TestCR(unittest.TestCase): | ||||
def test_load(self): | def test_load(self): | ||||
test_root = "test/data_for_tests/coreference/" | |||||
test_root = "test/data_for_tests/io/coreference/" | |||||
train_path = test_root+"coreference_train.json" | train_path = test_root+"coreference_train.json" | ||||
dev_path = test_root+"coreference_dev.json" | dev_path = test_root+"coreference_dev.json" | ||||
test_path = test_root+"coreference_test.json" | test_path = test_root+"coreference_test.json" | ||||
paths = {"train": train_path,"dev":dev_path,"test":test_path} | |||||
paths = {"train": train_path, "dev": dev_path, "test": test_path} | |||||
bundle1 = CRLoader().load(paths) | |||||
bundle2 = CRLoader().load(test_root) | |||||
bundle1 = CoReferenceLoader().load(paths) | |||||
bundle2 = CoReferenceLoader().load(test_root) | |||||
print(bundle1) | print(bundle1) | ||||
print(bundle2) | |||||
print(bundle2) | |||||
self.assertEqual(bundle1.num_dataset, 3) | |||||
self.assertEqual(bundle2.num_dataset, 3) | |||||
self.assertEqual(bundle1.num_vocab, 0) | |||||
self.assertEqual(bundle2.num_vocab, 0) | |||||
self.assertEqual(len(bundle1.get_dataset('train')), 1) | |||||
self.assertEqual(len(bundle1.get_dataset('dev')), 1) | |||||
self.assertEqual(len(bundle1.get_dataset('test')), 1) |
@@ -1,5 +1,5 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.io.pipe.coreference import CoreferencePipe | |||||
from fastNLP.io.pipe.coreference import CoReferencePipe | |||||
class TestCR(unittest.TestCase): | class TestCR(unittest.TestCase): | ||||
@@ -11,14 +11,23 @@ class TestCR(unittest.TestCase): | |||||
char_path = None | char_path = None | ||||
config = Config() | config = Config() | ||||
file_root_path = "test/data_for_tests/coreference/" | |||||
file_root_path = "test/data_for_tests/io/coreference/" | |||||
train_path = file_root_path + "coreference_train.json" | train_path = file_root_path + "coreference_train.json" | ||||
dev_path = file_root_path + "coreference_dev.json" | dev_path = file_root_path + "coreference_dev.json" | ||||
test_path = file_root_path + "coreference_test.json" | test_path = file_root_path + "coreference_test.json" | ||||
paths = {"train": train_path, "dev": dev_path, "test": test_path} | paths = {"train": train_path, "dev": dev_path, "test": test_path} | ||||
bundle1 = CoreferencePipe(config).process_from_file(paths) | |||||
bundle2 = CoreferencePipe(config).process_from_file(file_root_path) | |||||
bundle1 = CoReferencePipe(config).process_from_file(paths) | |||||
bundle2 = CoReferencePipe(config).process_from_file(file_root_path) | |||||
print(bundle1) | print(bundle1) | ||||
print(bundle2) | |||||
print(bundle2) | |||||
self.assertEqual(bundle1.num_dataset, 3) | |||||
self.assertEqual(bundle2.num_dataset, 3) | |||||
self.assertEqual(bundle1.num_vocab, 1) | |||||
self.assertEqual(bundle2.num_vocab, 1) | |||||
self.assertEqual(len(bundle1.get_dataset('train')), 1) | |||||
self.assertEqual(len(bundle1.get_dataset('dev')), 1) | |||||
self.assertEqual(len(bundle1.get_dataset('test')), 1) | |||||
self.assertEqual(len(bundle1.get_vocab('words1')), 84) |