From b4e542095d34e3831a7f98b3d4e9e0a41e6e3f77 Mon Sep 17 00:00:00 2001
From: xxliu
Date: Mon, 26 Aug 2019 19:21:35 +0800
Subject: [PATCH] pipe

---
 fastNLP/io/loader/__init__.py                 |   5 +-
 fastNLP/io/loader/coreference.py              |  24 ++++
 fastNLP/io/pipe/__init__.py                   |   3 +
 fastNLP/io/pipe/coreference.py                | 115 ++++++++++++++++++
 reproduction/coreference_resolution/README.md |   2 +-
 .../data_load/__init__.py                     |   0
 .../data_load/cr_loader.py                    |  68 -----------
 .../test/test_dataloader.py                   |  20 +--
 reproduction/coreference_resolution/train.py  |  10 +-
 reproduction/coreference_resolution/valid.py  |  10 +-
 10 files changed, 166 insertions(+), 91 deletions(-)
 create mode 100644 fastNLP/io/loader/coreference.py
 create mode 100644 fastNLP/io/pipe/coreference.py
 delete mode 100644 reproduction/coreference_resolution/data_load/__init__.py
 delete mode 100644 reproduction/coreference_resolution/data_load/cr_loader.py

diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py
index 6c23f213..aae3171a 100644
--- a/fastNLP/io/loader/__init__.py
+++ b/fastNLP/io/loader/__init__.py
@@ -71,7 +71,9 @@ __all__ = [
     "QuoraLoader",
     "SNLILoader",
     "QNLILoader",
-    "RTELoader"
+    "RTELoader",
+
+    "CRLoader"
 ]
 from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
 from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -81,3 +83,4 @@ from .json import JsonLoader
 from .loader import Loader
 from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
 from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
+from .coreference import CRLoader
\ No newline at end of file
diff --git a/fastNLP/io/loader/coreference.py b/fastNLP/io/loader/coreference.py
new file mode 100644
index 00000000..c8d9bbf5
--- /dev/null
+++ b/fastNLP/io/loader/coreference.py
@@ -0,0 +1,24 @@
+from ...core.dataset import DataSet
+from ..file_reader import _read_json
+from ...core.instance import Instance
+from .json import JsonLoader
+
+
+class CRLoader(JsonLoader):
+    def __init__(self, fields=None, dropna=False):
+        super().__init__(fields, dropna)
+
+    def _load(self, path):
+        """
+        加载数据
+        :param path:
+        :return:
+        """
+        dataset = DataSet()
+        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
+            if self.fields:
+                ins = {self.fields[k]: v for k, v in d.items()}
+            else:
+                ins = d
+            dataset.append(Instance(**ins))
+        return dataset
\ No newline at end of file
diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py
index 048e4cfe..d99b68c4 100644
--- a/fastNLP/io/pipe/__init__.py
+++ b/fastNLP/io/pipe/__init__.py
@@ -37,6 +37,8 @@ __all__ = [
     "QuoraPipe",
     "QNLIPipe",
     "MNLIPipe",
+
+    "CoreferencePipe"
 ]
 
 from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
@@ -46,3 +48,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
 from .pipe import Pipe
 from .conll import Conll2003Pipe
 from .cws import CWSPipe
+from .coreference import CoreferencePipe
diff --git a/fastNLP/io/pipe/coreference.py b/fastNLP/io/pipe/coreference.py
new file mode 100644
index 00000000..bdf6a132
--- /dev/null
+++ b/fastNLP/io/pipe/coreference.py
@@ -0,0 +1,115 @@
+__all__ = [
+    "CoreferencePipe"
+
+]
+
+from .pipe import Pipe
+from ..data_bundle import DataBundle
+from ..loader.coreference import CRLoader
+from fastNLP.core.vocabulary import Vocabulary
+import numpy as np
+import collections
+
+
+class CoreferencePipe(Pipe):
+
+    def __init__(self,config):
+        super().__init__()
+        self.config = config
+
+    def process(self, data_bundle: DataBundle):
+        genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
+        vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name='sentences')
+        vocab.build_vocab()
+        word2id = vocab.word2idx
+        char_dict = get_char_dict(self.config.char_path)
+        for name, ds in data_bundle.datasets.items():
+            ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter),
+                                         self.config.max_sentences, is_train=name == 'train')[0],
+                     new_field_name='doc_np')
+            ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter),
+                                         self.config.max_sentences, is_train=name == 'train')[1],
+                     new_field_name='char_index')
+            ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter),
+                                         self.config.max_sentences, is_train=name == 'train')[2],
+                     new_field_name='seq_len')
+            ds.apply(lambda x: speaker2numpy(x["speakers"], self.config.max_sentences, is_train=name == 'train'),
+                     new_field_name='speaker_ids_np')
+            ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')
+
+            ds.set_ignore_type('clusters')
+            ds.set_padder('clusters', None)
+            ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
+            ds.set_target("clusters")
+        return data_bundle
+
+    def process_from_file(self, paths):
+        bundle = CRLoader().load(paths)
+        return self.process(bundle)
+
+
+# helper
+
+def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train):
+    docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train)
+    assert max(length) == max_len
+    assert char_index.shape[0] == len(length)
+    assert char_index.shape[1] == max_len
+    doc_np = np.zeros((len(docvec), max_len), int)
+    for i in range(len(docvec)):
+        for j in range(len(docvec[i])):
+            doc_np[i][j] = docvec[i][j]
+    return doc_np, char_index, length
+
+def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train):
+    max_len = 0
+    max_word_length = 0
+    docvex = []
+    length = []
+    if is_train:
+        sent_num = min(max_sentences,len(doc))
+    else:
+        sent_num = len(doc)
+
+    for i in range(sent_num):
+        sent = doc[i]
+        length.append(len(sent))
+        if (len(sent) > max_len):
+            max_len = len(sent)
+        sent_vec =[]
+        for j,word in enumerate(sent):
+            if len(word)>max_word_length:
+                max_word_length = len(word)
+            if word in word2id:
+                sent_vec.append(word2id[word])
+            else:
+                sent_vec.append(word2id["UNK"])
+        docvex.append(sent_vec)
+
+    char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int)
+    for i in range(sent_num):
+        sent = doc[i]
+        for j,word in enumerate(sent):
+            char_index[i, j, :len(word)] = [char_dict[c] for c in word]
+
+    return docvex,char_index,length,max_len
+
+def speaker2numpy(speakers_raw,max_sentences,is_train):
+    if is_train and len(speakers_raw)> max_sentences:
+        speakers_raw = speakers_raw[0:max_sentences]
+    speakers = flatten(speakers_raw)
+    speaker_dict = {s: i for i, s in enumerate(set(speakers))}
+    speaker_ids = np.array([speaker_dict[s] for s in speakers])
+    return speaker_ids
+
+# 展平
+def flatten(l):
+    return [item for sublist in l for item in sublist]
+
+def get_char_dict(path):
+    vocab = [""]
+    with open(path) as f:
+        vocab.extend(c.strip() for c in f.readlines())
+    char_dict = collections.defaultdict(int)
+    char_dict.update({c: i for i, c in enumerate(vocab)})
+    return char_dict
\ No newline at end of file
diff --git a/reproduction/coreference_resolution/README.md b/reproduction/coreference_resolution/README.md
index 7cbcd052..c1a286e5 100644
--- a/reproduction/coreference_resolution/README.md
+++ b/reproduction/coreference_resolution/README.md
@@ -1,4 +1,4 @@
-# 共指消解复现
+# 指代消解复现
 ## 介绍
 Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。
 对于涉及自然语言理解的许多更高级别的NLP任务来说,
diff --git a/reproduction/coreference_resolution/data_load/__init__.py b/reproduction/coreference_resolution/data_load/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py
deleted file mode 100644
index 5ed73473..00000000
--- a/reproduction/coreference_resolution/data_load/cr_loader.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
-from fastNLP.io.file_reader import _read_json
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.data_bundle import DataBundle
-from reproduction.coreference_resolution.model.config import Config
-import reproduction.coreference_resolution.model.preprocess as preprocess
-
-
-class CRLoader(JsonLoader):
-    def __init__(self, fields=None, dropna=False):
-        super().__init__(fields, dropna)
-
-    def _load(self, path):
-        """
-        加载数据
-        :param path:
-        :return:
-        """
-        dataset = DataSet()
-        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
-            if self.fields:
-                ins = {self.fields[k]: v for k, v in d.items()}
-            else:
-                ins = d
-            dataset.append(Instance(**ins))
-        return dataset
-
-    def process(self, paths, **kwargs):
-        data_info = DataBundle()
-        for name in ['train', 'test', 'dev']:
-            data_info.datasets[name] = self.load(paths[name])
-
-        config = Config()
-        vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
-        vocab.build_vocab()
-        word2id = vocab.word2idx
-
-        char_dict = preprocess.get_char_dict(config.char_path)
-        data_info.vocabs = vocab
-
-        genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
-
-        for name, ds in data_info.datasets.items():
-            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
-                                                    config.max_sentences, is_train=name=='train')[0],
-                     new_field_name='doc_np')
-            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
-                                                    config.max_sentences, is_train=name=='train')[1],
-                     new_field_name='char_index')
-            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
-                                                    config.max_sentences, is_train=name=='train')[2],
-                     new_field_name='seq_len')
-            ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'),
-                     new_field_name='speaker_ids_np')
-            ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')
-
-            ds.set_ignore_type('clusters')
-            ds.set_padder('clusters', None)
-            ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
-            ds.set_target("clusters")
-
-        # train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
-        # train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)
-
-        return data_info
-
-
-
diff --git a/reproduction/coreference_resolution/test/test_dataloader.py b/reproduction/coreference_resolution/test/test_dataloader.py
index 0d9dae52..6a3be520 100644
--- a/reproduction/coreference_resolution/test/test_dataloader.py
+++ b/reproduction/coreference_resolution/test/test_dataloader.py
@@ -1,14 +1,14 @@
+
+
 import unittest
-from ..data_load.cr_loader import CRLoader
+from fastNLP.io.pipe.coreference import CoreferencePipe
+from reproduction.coreference_resolution.model.config import Config
 
 class Test_CRLoader(unittest.TestCase):
     def test_cr_loader(self):
-        train_path = 'data/train.english.jsonlines.mini'
-        dev_path = 'data/dev.english.jsonlines.minid'
-        test_path = 'data/test.english.jsonlines'
-        cr = CRLoader()
-        data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path})
-
-        print(data_info.datasets['train'][0])
-        print(data_info.datasets['dev'][0])
-        print(data_info.datasets['test'][0])
+        config = Config()
+        bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path})
+
+        print(bundle.datasets['train'][0])
+        print(bundle.datasets['dev'][0])
+        print(bundle.datasets['test'][0])
diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py
index a231a575..6c26cf4c 100644
--- a/reproduction/coreference_resolution/train.py
+++ b/reproduction/coreference_resolution/train.py
@@ -7,7 +7,8 @@ from torch.optim import Adam
 
 from fastNLP.core.callback import Callback, GradientClipCallback
 from fastNLP.core.trainer import Trainer
-from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from fastNLP.io.pipe.coreference import CoreferencePipe
+
 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.model_re import Model
 from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
@@ -38,11 +39,8 @@ if __name__ == "__main__":
 
     @cache_results('cache.pkl')
     def cache():
-        cr_train_dev_test = CRLoader()
-
-        data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
-                                               'test': config.test_path})
-        return data_info
+        bundle = CoreferencePipe(Config()).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path})
+        return bundle
     data_info = cache()
     print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])),
           "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
diff --git a/reproduction/coreference_resolution/valid.py b/reproduction/coreference_resolution/valid.py
index 826332c6..454629e1 100644
--- a/reproduction/coreference_resolution/valid.py
+++ b/reproduction/coreference_resolution/valid.py
@@ -1,7 +1,8 @@
 import torch
 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.metric import CRMetric
-from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from fastNLP.io.pipe.coreference import CoreferencePipe
+
 from fastNLP import Tester
 import argparse
 
@@ -11,13 +12,12 @@ if __name__=='__main__':
     parser.add_argument('--path')
     args = parser.parse_args()
 
-    cr_loader = CRLoader()
     config = Config()
-    data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
-                                   'test': config.test_path})
+    bundle = CoreferencePipe(Config()).process_from_file(
+        {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
     metirc = CRMetric()
    model = torch.load(args.path)
-    tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
+    tester = Tester(bundle.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
     tester.test()
     print('test over')
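
A minimal usage sketch of the pipeline introduced by this patch, mirroring train.py and valid.py above; it assumes the jsonlines paths and the preprocessing options (char_path, filter, max_sentences) are supplied by the reproduction's Config object:

    from fastNLP.io.pipe.coreference import CoreferencePipe
    from reproduction.coreference_resolution.model.config import Config

    config = Config()  # provides train_path/dev_path/test_path, char_path, filter, max_sentences
    bundle = CoreferencePipe(config).process_from_file(
        {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
    # doc_np, char_index, seq_len, speaker_ids_np and genre are set as input fields;
    # clusters is the (unpadded) target field.
    print(bundle.datasets['train'][0])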