fix code style in coreference task and related code

tags/v0.4.10
Yige Xu, 5 years ago
parent commit 753327d214
11 changed files with 57 additions and 44 deletions
1. fastNLP/io/loader/__init__.py (+2 -2)
2. fastNLP/io/loader/coreference.py (+8 -4)
3. fastNLP/io/pipe/__init__.py (+2 -2)
4. fastNLP/io/pipe/coreference.py (+5 -7)
5. reproduction/coreference_resolution/train.py (+8 -16)
6. reproduction/coreference_resolution/valid.py (+2 -2)
7. test/data_for_tests/io/coreference/coreference_dev.json (+0 -0)
8. test/data_for_tests/io/coreference/coreference_test.json (+0 -0)
9. test/data_for_tests/io/coreference/coreference_train.json (+0 -0)
10. test/io/loader/test_coreference_loader.py (+16 -6)
11. test/io/pipe/test_coreference.py (+14 -5)

fastNLP/io/loader/__init__.py (+2 -2)

@@ -74,7 +74,7 @@ __all__ = [
     "QNLILoader",
     "RTELoader",
 
-    "CRLoader"
+    "CoReferenceLoader"
 ]
 from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
 from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -84,4 +84,4 @@ from .json import JsonLoader
 from .loader import Loader
 from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
 from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
-from .coreference import CRLoader
+from .coreference import CoReferenceLoader

fastNLP/io/loader/coreference.py (+8 -4)

@@ -1,5 +1,9 @@
 """undocumented"""
 
+__all__ = [
+    "CoReferenceLoader",
+]
+
 from ...core.dataset import DataSet
 from ..file_reader import _read_json
 from ...core.instance import Instance
@@ -7,7 +11,7 @@ from ...core.const import Const
 from .json import JsonLoader
 
 
-class CRLoader(JsonLoader):
+class CoReferenceLoader(JsonLoader):
     """
     Each line of the raw data should be a JSON object, where doc_key carries the genre of the document, speakers lists the speaker of each sentence, clusters groups the mentions that refer to the same real-world entity, and sentences holds the text itself.
 
@@ -24,8 +28,8 @@ class CRLoader(JsonLoader):
     """
     def __init__(self, fields=None, dropna=False):
         super().__init__(fields, dropna)
-        # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
+        # TODO check 1
+        # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),
+        #                "clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
         self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
                        "sentences": Const.RAW_WORDS(3)}
 
@@ -43,4 +47,4 @@ class CRLoader(JsonLoader):
             else:
                 ins = d
             dataset.append(Instance(**ins))
-        return dataset
+        return dataset
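
For context, the loader consumes JSON-lines files of the shape the docstring describes, and load() accepts either a {"train"/"dev"/"test": path} dict or a directory, as the updated tests below exercise. A minimal sketch; the record in the comment is invented for illustration, and the directory path matches the test fixtures this commit relocates:

    from fastNLP.io.loader.coreference import CoReferenceLoader

    # One input line (values are hypothetical):
    # {"doc_key": "bn/0001", "speakers": [["spk1", "spk1", "spk1", "spk1"]],
    #  "clusters": [[[0, 0], [3, 3]]], "sentences": [["I", "like", "this", "."]]}

    # Load a train/dev/test split from a directory laid out like the test fixtures.
    data_bundle = CoReferenceLoader().load("test/data_for_tests/io/coreference/")
    print(data_bundle)  # the updated tests assert three datasets and no vocabularies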

fastNLP/io/pipe/__init__.py (+2 -2)

@@ -39,7 +39,7 @@ __all__ = [
     "QNLIPipe",
     "MNLIPipe",
 
-    "CoreferencePipe"
+    "CoReferencePipe"
 ]
 
 from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
@@ -49,4 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
 from .pipe import Pipe
 from .conll import Conll2003Pipe
 from .cws import CWSPipe
-from .coreference import CoreferencePipe
+from .coreference import CoReferencePipe

fastNLP/io/pipe/coreference.py (+5 -7)

@@ -1,8 +1,7 @@
 """undocumented"""
 
 __all__ = [
-    "CoreferencePipe"
-
+    "CoReferencePipe"
 ]
 
 import collections
@@ -12,11 +11,11 @@ import numpy as np
 from fastNLP.core.vocabulary import Vocabulary
 from .pipe import Pipe
 from ..data_bundle import DataBundle
-from ..loader.coreference import CRLoader
+from ..loader.coreference import CoReferenceLoader
 from ...core.const import Const
 
 
-class CoreferencePipe(Pipe):
+class CoReferencePipe(Pipe):
     """
     Processes coreference-resolution data to produce the document genre, speakers, character-level information, and sequence lengths.
     """
@@ -52,7 +51,7 @@ class CoreferencePipe(Pipe):
         vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3))
         vocab.build_vocab()
         word2id = vocab.word2idx
-        data_bundle.set_vocab(vocab,Const.INPUT)
+        data_bundle.set_vocab(vocab, Const.INPUTS(0))
         if self.config.char_path:
             char_dict = get_char_dict(self.config.char_path)
         else:
@@ -93,7 +92,6 @@ class CoreferencePipe(Pipe):
         # clusters
         ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)
 
-
         ds.set_ignore_type(Const.TARGET)
         ds.set_padder(Const.TARGET, None)
         ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN)
@@ -102,7 +100,7 @@ class CoreferencePipe(Pipe):
         return data_bundle
 
     def process_from_file(self, paths):
-        bundle = CRLoader().load(paths)
+        bundle = CoReferenceLoader().load(paths)
         return self.process(bundle)
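
Note that the word vocabulary is now registered under Const.INPUTS(0) rather than Const.INPUT, which is why train.py below switches to get_vocab(Const.INPUTS(0)). A minimal usage sketch mirroring the updated test; the Config class here is a stand-in for the real configuration object, of which the pipe is only shown to use char_path:

    from fastNLP.core.const import Const
    from fastNLP.io.pipe.coreference import CoReferencePipe

    class Config:
        # None exercises the else-branch of the char-dict setup shown in the diff above
        char_path = None

    # directory layout as in the test fixtures under test/data_for_tests/io/coreference/
    data_bundle = CoReferencePipe(Config()).process_from_file("test/data_for_tests/io/coreference/")
    print(data_bundle.get_vocab(Const.INPUTS(0)))  # the vocabulary registered by process()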






reproduction/coreference_resolution/train.py (+8 -16)

@@ -1,5 +1,3 @@
-import sys
-sys.path.append('../..')
 
 import torch
 from torch.optim import Adam
@@ -7,20 +5,15 @@ from torch.optim import Adam
 from fastNLP.core.callback import Callback, GradientClipCallback
 from fastNLP.core.trainer import Trainer
 
-from fastNLP.io.pipe.coreference import CoreferencePipe
+from fastNLP.io.pipe.coreference import CoReferencePipe
 from fastNLP.core.const import Const
 
 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.model_re import Model
 from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
 from reproduction.coreference_resolution.model.metric import CRMetric
-from fastNLP import SequentialSampler
-from fastNLP import cache_results
 
 
-# torch.backends.cudnn.benchmark = False
-# torch.backends.cudnn.deterministic = True
-
 class LRCallback(Callback):
     def __init__(self, parameters, decay_rate=1e-3):
         super().__init__()
@@ -38,15 +31,13 @@ if __name__ == "__main__":
 
     print(config)
 
-    # @cache_results('cache.pkl')
     def cache():
-        bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path})
+        bundle = CoReferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,
+                                                            'test': config.test_path})
         return bundle
     data_bundle = cache()
-    print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))),
-          "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
-    # print(data_info)
-    model = Model(data_bundle.get_vocab(Const.INPUT), config)
+    print(data_bundle)
+    model = Model(data_bundle.get_vocab(Const.INPUTS(0)), config)
     print(model)
 
     loss = SoftmaxLoss()
@@ -59,9 +50,10 @@ if __name__ == "__main__":
 
     trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"],
                       loss=loss, metrics=metric, check_code_level=-1, sampler=None,
-                      batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
+                      batch_size=1, device=torch.device("cuda:" + config.cuda) if torch.cuda.is_available() else None,
+                      metric_key='f', n_epochs=config.epoch,
                       optimizer=optim,
-                      save_path= None,
+                      save_path=None,
                       callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
     print()
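
The dropped cache_results import went with the commented-out decorator above the cache() helper. If one wanted the preprocessing cached again, a sketch along the lines of the removed comment (the pickle name is arbitrary and not part of this commit):

    from fastNLP import cache_results

    @cache_results('coref_cache.pkl')  # first run pickles the bundle; later runs reload it
    def cache():
        return CoReferencePipe(config).process_from_file(
            {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})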




reproduction/coreference_resolution/valid.py (+2 -2)

@@ -1,7 +1,7 @@
 import torch
 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.metric import CRMetric
-from fastNLP.io.pipe.coreference import CoreferencePipe
+from fastNLP.io.pipe.coreference import CoReferencePipe
 
 from fastNLP import Tester
 import argparse
@@ -13,7 +13,7 @@ if __name__=='__main__':
     args = parser.parse_args()
     config = Config()
-    bundle = CoreferencePipe(Config()).process_from_file(
+    bundle = CoReferencePipe(Config()).process_from_file(
         {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
     metirc = CRMetric()
     model = torch.load(args.path)


test/data_for_tests/coreference/coreference_dev.json → test/data_for_tests/io/coreference/coreference_dev.json
test/data_for_tests/coreference/coreference_test.json → test/data_for_tests/io/coreference/coreference_test.json
test/data_for_tests/coreference/coreference_train.json → test/data_for_tests/io/coreference/coreference_train.json


test/io/loader/test_coreference_loader.py (+16 -6)

@@ -1,16 +1,26 @@
-from fastNLP.io.loader.coreference import CRLoader
+from fastNLP.io.loader.coreference import CoReferenceLoader
 import unittest
 
 class TestCR(unittest.TestCase):
     def test_load(self):
 
-        test_root = "test/data_for_tests/coreference/"
+        test_root = "test/data_for_tests/io/coreference/"
         train_path = test_root+"coreference_train.json"
        dev_path = test_root+"coreference_dev.json"
         test_path = test_root+"coreference_test.json"
-        paths = {"train": train_path,"dev":dev_path,"test":test_path}
+        paths = {"train": train_path, "dev": dev_path, "test": test_path}
 
-        bundle1 = CRLoader().load(paths)
-        bundle2 = CRLoader().load(test_root)
+        bundle1 = CoReferenceLoader().load(paths)
+        bundle2 = CoReferenceLoader().load(test_root)
         print(bundle1)
-        print(bundle2)
+        print(bundle2)
+
+        self.assertEqual(bundle1.num_dataset, 3)
+        self.assertEqual(bundle2.num_dataset, 3)
+        self.assertEqual(bundle1.num_vocab, 0)
+        self.assertEqual(bundle2.num_vocab, 0)
+
+        self.assertEqual(len(bundle1.get_dataset('train')), 1)
+        self.assertEqual(len(bundle1.get_dataset('dev')), 1)
+        self.assertEqual(len(bundle1.get_dataset('test')), 1)

test/io/pipe/test_coreference.py (+14 -5)

@@ -1,5 +1,5 @@
 import unittest
-from fastNLP.io.pipe.coreference import CoreferencePipe
+from fastNLP.io.pipe.coreference import CoReferencePipe
 
 
 class TestCR(unittest.TestCase):
@@ -11,14 +11,23 @@ class TestCR(unittest.TestCase):
             char_path = None
         config = Config()
 
-        file_root_path = "test/data_for_tests/coreference/"
+        file_root_path = "test/data_for_tests/io/coreference/"
         train_path = file_root_path + "coreference_train.json"
         dev_path = file_root_path + "coreference_dev.json"
         test_path = file_root_path + "coreference_test.json"
 
         paths = {"train": train_path, "dev": dev_path, "test": test_path}
 
-        bundle1 = CoreferencePipe(config).process_from_file(paths)
-        bundle2 = CoreferencePipe(config).process_from_file(file_root_path)
+        bundle1 = CoReferencePipe(config).process_from_file(paths)
+        bundle2 = CoReferencePipe(config).process_from_file(file_root_path)
         print(bundle1)
-        print(bundle2)
+        print(bundle2)
+        self.assertEqual(bundle1.num_dataset, 3)
+        self.assertEqual(bundle2.num_dataset, 3)
+        self.assertEqual(bundle1.num_vocab, 1)
+        self.assertEqual(bundle2.num_vocab, 1)
+
+        self.assertEqual(len(bundle1.get_dataset('train')), 1)
+        self.assertEqual(len(bundle1.get_dataset('dev')), 1)
+        self.assertEqual(len(bundle1.get_dataset('test')), 1)
+        self.assertEqual(len(bundle1.get_vocab('words1')), 84)