Browse Source

Fix code style in the coreference task and related code

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
753327d214
11 changed files with 57 additions and 44 deletions
  1. +2
    -2
      fastNLP/io/loader/__init__.py
  2. +8
    -4
      fastNLP/io/loader/coreference.py
  3. +2
    -2
      fastNLP/io/pipe/__init__.py
  4. +5
    -7
      fastNLP/io/pipe/coreference.py
  5. +8
    -16
      reproduction/coreference_resolution/train.py
  6. +2
    -2
      reproduction/coreference_resolution/valid.py
  7. +0
    -0
      test/data_for_tests/io/coreference/coreference_dev.json
  8. +0
    -0
      test/data_for_tests/io/coreference/coreference_test.json
  9. +0
    -0
      test/data_for_tests/io/coreference/coreference_train.json
  10. +16
    -6
      test/io/loader/test_coreference_loader.py
  11. +14
    -5
      test/io/pipe/test_coreference.py

+ 2
- 2
fastNLP/io/loader/__init__.py View File

@@ -74,7 +74,7 @@ __all__ = [
"QNLILoader",
"RTELoader",

"CRLoader"
"CoReferenceLoader"
]
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -84,4 +84,4 @@ from .json import JsonLoader
from .loader import Loader
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
from .coreference import CRLoader
from .coreference import CoReferenceLoader

+ 8
- 4
fastNLP/io/loader/coreference.py View File

@@ -1,5 +1,9 @@
"""undocumented"""

__all__ = [
"CoReferenceLoader",
]

from ...core.dataset import DataSet
from ..file_reader import _read_json
from ...core.instance import Instance
@@ -7,7 +11,7 @@ from ...core.const import Const
from .json import JsonLoader


class CRLoader(JsonLoader):
class CoReferenceLoader(JsonLoader):
"""
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。

@@ -24,8 +28,8 @@ class CRLoader(JsonLoader):
"""
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
# TODO check 1
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),
# "clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
"sentences": Const.RAW_WORDS(3)}

@@ -43,4 +47,4 @@ class CRLoader(JsonLoader):
else:
ins = d
dataset.append(Instance(**ins))
return dataset
return dataset

+ 2
- 2
fastNLP/io/pipe/__init__.py View File

@@ -39,7 +39,7 @@ __all__ = [
"QNLIPipe",
"MNLIPipe",

"CoreferencePipe"
"CoReferencePipe"
]

from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
@@ -49,4 +49,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
from .pipe import Pipe
from .conll import Conll2003Pipe
from .cws import CWSPipe
from .coreference import CoreferencePipe
from .coreference import CoReferencePipe

+ 5
- 7
fastNLP/io/pipe/coreference.py View File

@@ -1,8 +1,7 @@
"""undocumented"""

__all__ = [
"CoreferencePipe"

"CoReferencePipe"
]

import collections
@@ -12,11 +11,11 @@ import numpy as np
from fastNLP.core.vocabulary import Vocabulary
from .pipe import Pipe
from ..data_bundle import DataBundle
from ..loader.coreference import CRLoader
from ..loader.coreference import CoReferenceLoader
from ...core.const import Const


class CoreferencePipe(Pipe):
class CoReferencePipe(Pipe):
"""
对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。
"""
@@ -52,7 +51,7 @@ class CoreferencePipe(Pipe):
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3))
vocab.build_vocab()
word2id = vocab.word2idx
data_bundle.set_vocab(vocab,Const.INPUT)
data_bundle.set_vocab(vocab, Const.INPUTS(0))
if self.config.char_path:
char_dict = get_char_dict(self.config.char_path)
else:
@@ -93,7 +92,6 @@ class CoreferencePipe(Pipe):
# clusters
ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)


ds.set_ignore_type(Const.TARGET)
ds.set_padder(Const.TARGET, None)
ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2), Const.INPUTS(3), Const.CHAR_INPUT, Const.INPUT_LEN)
@@ -102,7 +100,7 @@ class CoreferencePipe(Pipe):
return data_bundle

def process_from_file(self, paths):
bundle = CRLoader().load(paths)
bundle = CoReferenceLoader().load(paths)
return self.process(bundle)




+ 8
- 16
reproduction/coreference_resolution/train.py View File

@@ -1,5 +1,3 @@
import sys
sys.path.append('../..')

import torch
from torch.optim import Adam
@@ -7,20 +5,15 @@ from torch.optim import Adam
from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer

from fastNLP.io.pipe.coreference import CoreferencePipe
from fastNLP.io.pipe.coreference import CoReferencePipe
from fastNLP.core.const import Const

from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
from reproduction.coreference_resolution.model.metric import CRMetric
from fastNLP import SequentialSampler
from fastNLP import cache_results


# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

class LRCallback(Callback):
def __init__(self, parameters, decay_rate=1e-3):
super().__init__()
@@ -38,15 +31,13 @@ if __name__ == "__main__":

print(config)

# @cache_results('cache.pkl')
def cache():
bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path})
bundle = CoReferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,
'test': config.test_path})
return bundle
data_bundle = cache()
print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))),
"\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
# print(data_info)
model = Model(data_bundle.get_vocab(Const.INPUT), config)
print(data_bundle)
model = Model(data_bundle.get_vocab(Const.INPUTS(0)), config)
print(model)

loss = SoftmaxLoss()
@@ -59,9 +50,10 @@ if __name__ == "__main__":

trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"],
loss=loss, metrics=metric, check_code_level=-1, sampler=None,
batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
batch_size=1, device=torch.device("cuda:" + config.cuda) if torch.cuda.is_available() else None,
metric_key='f', n_epochs=config.epoch,
optimizer=optim,
save_path= None,
save_path=None,
callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
print()



+ 2
- 2
reproduction/coreference_resolution/valid.py View File

@@ -1,7 +1,7 @@
import torch
from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.metric import CRMetric
from fastNLP.io.pipe.coreference import CoreferencePipe
from fastNLP.io.pipe.coreference import CoReferencePipe

from fastNLP import Tester
import argparse
@@ -13,7 +13,7 @@ if __name__=='__main__':
args = parser.parse_args()
config = Config()
bundle = CoreferencePipe(Config()).process_from_file(
bundle = CoReferencePipe(Config()).process_from_file(
{'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
metirc = CRMetric()
model = torch.load(args.path)


test/data_for_tests/coreference/coreference_dev.json → test/data_for_tests/io/coreference/coreference_dev.json View File


test/data_for_tests/coreference/coreference_test.json → test/data_for_tests/io/coreference/coreference_test.json View File


test/data_for_tests/coreference/coreference_train.json → test/data_for_tests/io/coreference/coreference_train.json View File


+ 16
- 6
test/io/loader/test_coreference_loader.py View File

@@ -1,16 +1,26 @@
from fastNLP.io.loader.coreference import CRLoader
from fastNLP.io.loader.coreference import CoReferenceLoader
import unittest


class TestCR(unittest.TestCase):
def test_load(self):

test_root = "test/data_for_tests/coreference/"
test_root = "test/data_for_tests/io/coreference/"
train_path = test_root+"coreference_train.json"
dev_path = test_root+"coreference_dev.json"
test_path = test_root+"coreference_test.json"
paths = {"train": train_path,"dev":dev_path,"test":test_path}
paths = {"train": train_path, "dev": dev_path, "test": test_path}

bundle1 = CRLoader().load(paths)
bundle2 = CRLoader().load(test_root)
bundle1 = CoReferenceLoader().load(paths)
bundle2 = CoReferenceLoader().load(test_root)
print(bundle1)
print(bundle2)
print(bundle2)

self.assertEqual(bundle1.num_dataset, 3)
self.assertEqual(bundle2.num_dataset, 3)
self.assertEqual(bundle1.num_vocab, 0)
self.assertEqual(bundle2.num_vocab, 0)

self.assertEqual(len(bundle1.get_dataset('train')), 1)
self.assertEqual(len(bundle1.get_dataset('dev')), 1)
self.assertEqual(len(bundle1.get_dataset('test')), 1)

+ 14
- 5
test/io/pipe/test_coreference.py View File

@@ -1,5 +1,5 @@
import unittest
from fastNLP.io.pipe.coreference import CoreferencePipe
from fastNLP.io.pipe.coreference import CoReferencePipe


class TestCR(unittest.TestCase):
@@ -11,14 +11,23 @@ class TestCR(unittest.TestCase):
char_path = None
config = Config()

file_root_path = "test/data_for_tests/coreference/"
file_root_path = "test/data_for_tests/io/coreference/"
train_path = file_root_path + "coreference_train.json"
dev_path = file_root_path + "coreference_dev.json"
test_path = file_root_path + "coreference_test.json"

paths = {"train": train_path, "dev": dev_path, "test": test_path}

bundle1 = CoreferencePipe(config).process_from_file(paths)
bundle2 = CoreferencePipe(config).process_from_file(file_root_path)
bundle1 = CoReferencePipe(config).process_from_file(paths)
bundle2 = CoReferencePipe(config).process_from_file(file_root_path)
print(bundle1)
print(bundle2)
print(bundle2)
self.assertEqual(bundle1.num_dataset, 3)
self.assertEqual(bundle2.num_dataset, 3)
self.assertEqual(bundle1.num_vocab, 1)
self.assertEqual(bundle2.num_vocab, 1)

self.assertEqual(len(bundle1.get_dataset('train')), 1)
self.assertEqual(len(bundle1.get_dataset('dev')), 1)
self.assertEqual(len(bundle1.get_dataset('test')), 1)
self.assertEqual(len(bundle1.get_vocab('words1')), 84)

Loading…
Cancel
Save