| @@ -74,7 +74,7 @@ __all__ = [ | |||
| "QNLILoader", | |||
| "RTELoader", | |||
| "CRLoader" | |||
| "CoReferenceLoader" | |||
| ] | |||
| from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader | |||
| from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | |||
| @@ -84,4 +84,4 @@ from .json import JsonLoader | |||
| from .loader import Loader | |||
| from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | |||
| from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | |||
| from .coreference import CRLoader | |||
| from .coreference import CoReferenceLoader | |||
| @@ -1,5 +1,9 @@ | |||
| """undocumented""" | |||
| __all__ = [ | |||
| "CoReferenceLoader", | |||
| ] | |||
| from ...core.dataset import DataSet | |||
| from ..file_reader import _read_json | |||
| from ...core.instance import Instance | |||
| @@ -7,7 +11,7 @@ from ...core.const import Const | |||
| from .json import JsonLoader | |||
| class CRLoader(JsonLoader): | |||
| class CoReferenceLoader(JsonLoader): | |||
| """ | |||
| 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | |||
| @@ -24,8 +28,8 @@ class CRLoader(JsonLoader): | |||
| """ | |||
| def __init__(self, fields=None, dropna=False): | |||
| super().__init__(fields, dropna) | |||
| # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)} | |||
| # TODO check 1 | |||
| # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1), | |||
| # "clusters":Const.TARGET,"sentences":Const.INPUTS(2)} | |||
| self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), | |||
| "sentences": Const.RAW_WORDS(3)} | |||
| @@ -43,4 +47,4 @@ class CRLoader(JsonLoader): | |||
| else: | |||
| ins = d | |||
| dataset.append(Instance(**ins)) | |||
| return dataset | |||
| return dataset | |||
| @@ -39,7 +39,7 @@ __all__ = [ | |||
| "QNLIPipe", | |||
| "MNLIPipe", | |||
| "CoreferencePipe" | |||
| "CoReferencePipe" | |||
| ] | |||
| from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe | |||
| @@ -49,4 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe | |||
| from .pipe import Pipe | |||
| from .conll import Conll2003Pipe | |||
| from .cws import CWSPipe | |||
| from .coreference import CoreferencePipe | |||
| from .coreference import CoReferencePipe | |||
| @@ -1,8 +1,7 @@ | |||
| """undocumented""" | |||
| __all__ = [ | |||
| "CoreferencePipe" | |||
| "CoReferencePipe" | |||
| ] | |||
| import collections | |||
| @@ -12,11 +11,11 @@ import numpy as np | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from .pipe import Pipe | |||
| from ..data_bundle import DataBundle | |||
| from ..loader.coreference import CRLoader | |||
| from ..loader.coreference import CoReferenceLoader | |||
| from ...core.const import Const | |||
| class CoreferencePipe(Pipe): | |||
| class CoReferencePipe(Pipe): | |||
| """ | |||
| 对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | |||
| """ | |||
| @@ -52,7 +51,7 @@ class CoreferencePipe(Pipe): | |||
| vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3)) | |||
| vocab.build_vocab() | |||
| word2id = vocab.word2idx | |||
| data_bundle.set_vocab(vocab,Const.INPUT) | |||
| data_bundle.set_vocab(vocab, Const.INPUTS(0)) | |||
| if self.config.char_path: | |||
| char_dict = get_char_dict(self.config.char_path) | |||
| else: | |||
| @@ -93,7 +92,6 @@ class CoreferencePipe(Pipe): | |||
| # clusters | |||
| ds.rename_field(Const.RAW_WORDS(2), Const.TARGET) | |||
| ds.set_ignore_type(Const.TARGET) | |||
| ds.set_padder(Const.TARGET, None) | |||
| ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN) | |||
| @@ -102,7 +100,7 @@ class CoreferencePipe(Pipe): | |||
| return data_bundle | |||
| def process_from_file(self, paths): | |||
| bundle = CRLoader().load(paths) | |||
| bundle = CoReferenceLoader().load(paths) | |||
| return self.process(bundle) | |||
| @@ -1,5 +1,3 @@ | |||
| import sys | |||
| sys.path.append('../..') | |||
| import torch | |||
| from torch.optim import Adam | |||
| @@ -7,20 +5,15 @@ from torch.optim import Adam | |||
| from fastNLP.core.callback import Callback, GradientClipCallback | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP.io.pipe.coreference import CoreferencePipe | |||
| from fastNLP.io.pipe.coreference import CoReferencePipe | |||
| from fastNLP.core.const import Const | |||
| from reproduction.coreference_resolution.model.config import Config | |||
| from reproduction.coreference_resolution.model.model_re import Model | |||
| from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | |||
| from reproduction.coreference_resolution.model.metric import CRMetric | |||
| from fastNLP import SequentialSampler | |||
| from fastNLP import cache_results | |||
| # torch.backends.cudnn.benchmark = False | |||
| # torch.backends.cudnn.deterministic = True | |||
| class LRCallback(Callback): | |||
| def __init__(self, parameters, decay_rate=1e-3): | |||
| super().__init__() | |||
| @@ -38,15 +31,13 @@ if __name__ == "__main__": | |||
| print(config) | |||
| # @cache_results('cache.pkl') | |||
| def cache(): | |||
| bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path}) | |||
| bundle = CoReferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, | |||
| 'test': config.test_path}) | |||
| return bundle | |||
| data_bundle = cache() | |||
| print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), | |||
| "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) | |||
| # print(data_info) | |||
| model = Model(data_bundle.get_vocab(Const.INPUT), config) | |||
| print(data_bundle) | |||
| model = Model(data_bundle.get_vocab(Const.INPUTS(0)), config) | |||
| print(model) | |||
| loss = SoftmaxLoss() | |||
| @@ -59,9 +50,10 @@ if __name__ == "__main__": | |||
| trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"], | |||
| loss=loss, metrics=metric, check_code_level=-1, sampler=None, | |||
| batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, | |||
| batch_size=1, device=torch.device("cuda:" + config.cuda) if torch.cuda.is_available() else None, | |||
| metric_key='f', n_epochs=config.epoch, | |||
| optimizer=optim, | |||
| save_path= None, | |||
| save_path=None, | |||
| callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) | |||
| print() | |||
| @@ -1,7 +1,7 @@ | |||
| import torch | |||
| from reproduction.coreference_resolution.model.config import Config | |||
| from reproduction.coreference_resolution.model.metric import CRMetric | |||
| from fastNLP.io.pipe.coreference import CoreferencePipe | |||
| from fastNLP.io.pipe.coreference import CoReferencePipe | |||
| from fastNLP import Tester | |||
| import argparse | |||
| @@ -13,7 +13,7 @@ if __name__=='__main__': | |||
| args = parser.parse_args() | |||
| config = Config() | |||
| bundle = CoreferencePipe(Config()).process_from_file( | |||
| bundle = CoReferencePipe(Config()).process_from_file( | |||
| {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path}) | |||
| metirc = CRMetric() | |||
| model = torch.load(args.path) | |||
| @@ -1,16 +1,26 @@ | |||
| from fastNLP.io.loader.coreference import CRLoader | |||
| from fastNLP.io.loader.coreference import CoReferenceLoader | |||
| import unittest | |||
| class TestCR(unittest.TestCase): | |||
| def test_load(self): | |||
| test_root = "test/data_for_tests/coreference/" | |||
| test_root = "test/data_for_tests/io/coreference/" | |||
| train_path = test_root+"coreference_train.json" | |||
| dev_path = test_root+"coreference_dev.json" | |||
| test_path = test_root+"coreference_test.json" | |||
| paths = {"train": train_path,"dev":dev_path,"test":test_path} | |||
| paths = {"train": train_path, "dev": dev_path, "test": test_path} | |||
| bundle1 = CRLoader().load(paths) | |||
| bundle2 = CRLoader().load(test_root) | |||
| bundle1 = CoReferenceLoader().load(paths) | |||
| bundle2 = CoReferenceLoader().load(test_root) | |||
| print(bundle1) | |||
| print(bundle2) | |||
| print(bundle2) | |||
| self.assertEqual(bundle1.num_dataset, 3) | |||
| self.assertEqual(bundle2.num_dataset, 3) | |||
| self.assertEqual(bundle1.num_vocab, 0) | |||
| self.assertEqual(bundle2.num_vocab, 0) | |||
| self.assertEqual(len(bundle1.get_dataset('train')), 1) | |||
| self.assertEqual(len(bundle1.get_dataset('dev')), 1) | |||
| self.assertEqual(len(bundle1.get_dataset('test')), 1) | |||
| @@ -1,5 +1,5 @@ | |||
| import unittest | |||
| from fastNLP.io.pipe.coreference import CoreferencePipe | |||
| from fastNLP.io.pipe.coreference import CoReferencePipe | |||
| class TestCR(unittest.TestCase): | |||
| @@ -11,14 +11,23 @@ class TestCR(unittest.TestCase): | |||
| char_path = None | |||
| config = Config() | |||
| file_root_path = "test/data_for_tests/coreference/" | |||
| file_root_path = "test/data_for_tests/io/coreference/" | |||
| train_path = file_root_path + "coreference_train.json" | |||
| dev_path = file_root_path + "coreference_dev.json" | |||
| test_path = file_root_path + "coreference_test.json" | |||
| paths = {"train": train_path, "dev": dev_path, "test": test_path} | |||
| bundle1 = CoreferencePipe(config).process_from_file(paths) | |||
| bundle2 = CoreferencePipe(config).process_from_file(file_root_path) | |||
| bundle1 = CoReferencePipe(config).process_from_file(paths) | |||
| bundle2 = CoReferencePipe(config).process_from_file(file_root_path) | |||
| print(bundle1) | |||
| print(bundle2) | |||
| print(bundle2) | |||
| self.assertEqual(bundle1.num_dataset, 3) | |||
| self.assertEqual(bundle2.num_dataset, 3) | |||
| self.assertEqual(bundle1.num_vocab, 1) | |||
| self.assertEqual(bundle2.num_vocab, 1) | |||
| self.assertEqual(len(bundle1.get_dataset('train')), 1) | |||
| self.assertEqual(len(bundle1.get_dataset('dev')), 1) | |||
| self.assertEqual(len(bundle1.get_dataset('test')), 1) | |||
| self.assertEqual(len(bundle1.get_vocab('words1')), 84) | |||