From 753327d214e296b96e00b19ba0d267c61d7d5d7d Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Wed, 11 Sep 2019 02:16:57 +0800 Subject: [PATCH] fix code style in coreference task and related codes --- fastNLP/io/loader/__init__.py | 4 ++-- fastNLP/io/loader/coreference.py | 12 ++++++---- fastNLP/io/pipe/__init__.py | 4 ++-- fastNLP/io/pipe/coreference.py | 12 ++++------ reproduction/coreference_resolution/train.py | 24 +++++++------------ reproduction/coreference_resolution/valid.py | 4 ++-- .../{ => io}/coreference/coreference_dev.json | 0 .../coreference/coreference_test.json | 0 .../coreference/coreference_train.json | 0 test/io/loader/test_coreference_loader.py | 22 ++++++++++++----- test/io/pipe/test_coreference.py | 19 +++++++++++---- 11 files changed, 57 insertions(+), 44 deletions(-) rename test/data_for_tests/{ => io}/coreference/coreference_dev.json (100%) rename test/data_for_tests/{ => io}/coreference/coreference_test.json (100%) rename test/data_for_tests/{ => io}/coreference/coreference_train.json (100%) diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index cf88e8c0..06ad57c3 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -74,7 +74,7 @@ __all__ = [ "QNLILoader", "RTELoader", - "CRLoader" + "CoReferenceLoader" ] from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader @@ -84,4 +84,4 @@ from .json import JsonLoader from .loader import Loader from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader -from .coreference import CRLoader \ No newline at end of file +from .coreference import CoReferenceLoader \ No newline at end of file diff --git a/fastNLP/io/loader/coreference.py b/fastNLP/io/loader/coreference.py index 714b11e5..4293f65a 100644 --- a/fastNLP/io/loader/coreference.py +++ b/fastNLP/io/loader/coreference.py @@ -1,5 +1,9 @@ """undocumented""" +__all__ = [ + "CoReferenceLoader", +] + from ...core.dataset import DataSet from ..file_reader import _read_json from ...core.instance import Instance @@ -7,7 +11,7 @@ from ...core.const import Const from .json import JsonLoader -class CRLoader(JsonLoader): +class CoReferenceLoader(JsonLoader): """ 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 @@ -24,8 +28,8 @@ class CRLoader(JsonLoader): """ def __init__(self, fields=None, dropna=False): super().__init__(fields, dropna) - # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)} - # TODO check 1 + # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1), + # "clusters":Const.TARGET,"sentences":Const.INPUTS(2)} self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), "sentences": Const.RAW_WORDS(3)} @@ -43,4 +47,4 @@ class CRLoader(JsonLoader): else: ins = d dataset.append(Instance(**ins)) - return dataset \ No newline at end of file + return dataset diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index f3534cc2..0ddb1f2d 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -39,7 +39,7 @@ __all__ = [ "QNLIPipe", "MNLIPipe", - "CoreferencePipe" + "CoReferencePipe" ] from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe @@ -49,4 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe from .pipe import Pipe from .conll import Conll2003Pipe from .cws import CWSPipe -from .coreference import CoreferencePipe +from .coreference import CoReferencePipe diff --git a/fastNLP/io/pipe/coreference.py b/fastNLP/io/pipe/coreference.py index 3c171507..c1b218a5 100644 --- a/fastNLP/io/pipe/coreference.py +++ b/fastNLP/io/pipe/coreference.py @@ -1,8 +1,7 @@ """undocumented""" __all__ = [ - "CoreferencePipe" - + "CoReferencePipe" ] import collections @@ -12,11 +11,11 @@ import numpy as np from fastNLP.core.vocabulary import Vocabulary from .pipe import Pipe from ..data_bundle import DataBundle -from ..loader.coreference import CRLoader +from ..loader.coreference import CoReferenceLoader from ...core.const import Const -class CoreferencePipe(Pipe): +class CoReferencePipe(Pipe): """ 对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 """ @@ -52,7 +51,7 @@ class CoreferencePipe(Pipe): vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3)) vocab.build_vocab() word2id = vocab.word2idx - data_bundle.set_vocab(vocab,Const.INPUT) + data_bundle.set_vocab(vocab, Const.INPUTS(0)) if self.config.char_path: char_dict = get_char_dict(self.config.char_path) else: @@ -93,7 +92,6 @@ class CoreferencePipe(Pipe): # clusters ds.rename_field(Const.RAW_WORDS(2), Const.TARGET) - ds.set_ignore_type(Const.TARGET) ds.set_padder(Const.TARGET, None) ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN) @@ -102,7 +100,7 @@ class CoreferencePipe(Pipe): return data_bundle def process_from_file(self, paths): - bundle = CRLoader().load(paths) + bundle = CoReferenceLoader().load(paths) return self.process(bundle) diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py index 23ba5d5b..d5445cd5 100644 --- a/reproduction/coreference_resolution/train.py +++ b/reproduction/coreference_resolution/train.py @@ -1,5 +1,3 @@ -import sys -sys.path.append('../..') import torch from torch.optim import Adam @@ -7,20 +5,15 @@ from torch.optim import Adam from fastNLP.core.callback import Callback, GradientClipCallback from fastNLP.core.trainer import Trainer -from fastNLP.io.pipe.coreference import CoreferencePipe +from fastNLP.io.pipe.coreference import CoReferencePipe from fastNLP.core.const import Const from reproduction.coreference_resolution.model.config import Config from reproduction.coreference_resolution.model.model_re import Model from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss from reproduction.coreference_resolution.model.metric import CRMetric -from fastNLP import SequentialSampler -from fastNLP import cache_results -# torch.backends.cudnn.benchmark = False -# torch.backends.cudnn.deterministic = True - class LRCallback(Callback): def __init__(self, parameters, decay_rate=1e-3): super().__init__() @@ -38,15 +31,13 @@ if __name__ == "__main__": print(config) - # @cache_results('cache.pkl') def cache(): - bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path}) + bundle = CoReferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, + 'test': config.test_path}) return bundle data_bundle = cache() - print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), - "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) - # print(data_info) - model = Model(data_bundle.get_vocab(Const.INPUT), config) + print(data_bundle) + model = Model(data_bundle.get_vocab(Const.INPUTS(0)), config) print(model) loss = SoftmaxLoss() @@ -59,9 +50,10 @@ if __name__ == "__main__": trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"], loss=loss, metrics=metric, check_code_level=-1, sampler=None, - batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, + batch_size=1, device=torch.device("cuda:" + config.cuda) if torch.cuda.is_available() else None, + metric_key='f', n_epochs=config.epoch, optimizer=optim, - save_path= None, + save_path=None, callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) print() diff --git a/reproduction/coreference_resolution/valid.py b/reproduction/coreference_resolution/valid.py index a528ea06..e79642b8 100644 --- a/reproduction/coreference_resolution/valid.py +++ b/reproduction/coreference_resolution/valid.py @@ -1,7 +1,7 @@ import torch from reproduction.coreference_resolution.model.config import Config from reproduction.coreference_resolution.model.metric import CRMetric -from fastNLP.io.pipe.coreference import CoreferencePipe +from fastNLP.io.pipe.coreference import CoReferencePipe from fastNLP import Tester import argparse @@ -13,7 +13,7 @@ if __name__=='__main__': args = parser.parse_args() config = Config() - bundle = CoreferencePipe(Config()).process_from_file( + bundle = CoReferencePipe(Config()).process_from_file( {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path}) metirc = CRMetric() model = torch.load(args.path) diff --git a/test/data_for_tests/coreference/coreference_dev.json b/test/data_for_tests/io/coreference/coreference_dev.json similarity index 100% rename from test/data_for_tests/coreference/coreference_dev.json rename to test/data_for_tests/io/coreference/coreference_dev.json diff --git a/test/data_for_tests/coreference/coreference_test.json b/test/data_for_tests/io/coreference/coreference_test.json similarity index 100% rename from test/data_for_tests/coreference/coreference_test.json rename to test/data_for_tests/io/coreference/coreference_test.json diff --git a/test/data_for_tests/coreference/coreference_train.json b/test/data_for_tests/io/coreference/coreference_train.json similarity index 100% rename from test/data_for_tests/coreference/coreference_train.json rename to test/data_for_tests/io/coreference/coreference_train.json diff --git a/test/io/loader/test_coreference_loader.py b/test/io/loader/test_coreference_loader.py index d827e947..02f3a1c5 100644 --- a/test/io/loader/test_coreference_loader.py +++ b/test/io/loader/test_coreference_loader.py @@ -1,16 +1,26 @@ -from fastNLP.io.loader.coreference import CRLoader +from fastNLP.io.loader.coreference import CoReferenceLoader import unittest + class TestCR(unittest.TestCase): def test_load(self): - test_root = "test/data_for_tests/coreference/" + test_root = "test/data_for_tests/io/coreference/" train_path = test_root+"coreference_train.json" dev_path = test_root+"coreference_dev.json" test_path = test_root+"coreference_test.json" - paths = {"train": train_path,"dev":dev_path,"test":test_path} + paths = {"train": train_path, "dev": dev_path, "test": test_path} - bundle1 = CRLoader().load(paths) - bundle2 = CRLoader().load(test_root) + bundle1 = CoReferenceLoader().load(paths) + bundle2 = CoReferenceLoader().load(test_root) print(bundle1) - print(bundle2) \ No newline at end of file + print(bundle2) + + self.assertEqual(bundle1.num_dataset, 3) + self.assertEqual(bundle2.num_dataset, 3) + self.assertEqual(bundle1.num_vocab, 0) + self.assertEqual(bundle2.num_vocab, 0) + + self.assertEqual(len(bundle1.get_dataset('train')), 1) + self.assertEqual(len(bundle1.get_dataset('dev')), 1) + self.assertEqual(len(bundle1.get_dataset('test')), 1) diff --git a/test/io/pipe/test_coreference.py b/test/io/pipe/test_coreference.py index 517be993..3a492419 100644 --- a/test/io/pipe/test_coreference.py +++ b/test/io/pipe/test_coreference.py @@ -1,5 +1,5 @@ import unittest -from fastNLP.io.pipe.coreference import CoreferencePipe +from fastNLP.io.pipe.coreference import CoReferencePipe class TestCR(unittest.TestCase): @@ -11,14 +11,23 @@ class TestCR(unittest.TestCase): char_path = None config = Config() - file_root_path = "test/data_for_tests/coreference/" + file_root_path = "test/data_for_tests/io/coreference/" train_path = file_root_path + "coreference_train.json" dev_path = file_root_path + "coreference_dev.json" test_path = file_root_path + "coreference_test.json" paths = {"train": train_path, "dev": dev_path, "test": test_path} - bundle1 = CoreferencePipe(config).process_from_file(paths) - bundle2 = CoreferencePipe(config).process_from_file(file_root_path) + bundle1 = CoReferencePipe(config).process_from_file(paths) + bundle2 = CoReferencePipe(config).process_from_file(file_root_path) print(bundle1) - print(bundle2) \ No newline at end of file + print(bundle2) + self.assertEqual(bundle1.num_dataset, 3) + self.assertEqual(bundle2.num_dataset, 3) + self.assertEqual(bundle1.num_vocab, 1) + self.assertEqual(bundle2.num_vocab, 1) + + self.assertEqual(len(bundle1.get_dataset('train')), 1) + self.assertEqual(len(bundle1.get_dataset('dev')), 1) + self.assertEqual(len(bundle1.get_dataset('test')), 1) + self.assertEqual(len(bundle1.get_vocab('words1')), 84)