
pipe

tags/v0.4.10
xxliu, 5 years ago, commit b4e542095d
10 changed files with 166 additions and 91 deletions
  1. +4   -1   fastNLP/io/loader/__init__.py
  2. +24  -0   fastNLP/io/loader/coreference.py
  3. +3   -0   fastNLP/io/pipe/__init__.py
  4. +115 -0   fastNLP/io/pipe/coreference.py
  5. +1   -1   reproduction/coreference_resolution/README.md
  6. +0   -0   reproduction/coreference_resolution/data_load/__init__.py
  7. +0   -68  reproduction/coreference_resolution/data_load/cr_loader.py
  8. +10  -10  reproduction/coreference_resolution/test/test_dataloader.py
  9. +4   -6   reproduction/coreference_resolution/train.py
  10. +5  -5   reproduction/coreference_resolution/valid.py

+4 -1  fastNLP/io/loader/__init__.py

@@ -71,7 +71,9 @@ __all__ = [
"QuoraLoader",
"SNLILoader",
"QNLILoader",
"RTELoader"
"RTELoader",

"CRLoader"
]
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
@@ -81,3 +83,4 @@ from .json import JsonLoader
from .loader import Loader
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
from .coreference import CRLoader

+24 -0  fastNLP/io/loader/coreference.py

@@ -0,0 +1,24 @@
from ...core.dataset import DataSet
from ..file_reader import _read_json
from ...core.instance import Instance
from .json import JsonLoader


class CRLoader(JsonLoader):
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)

def _load(self, path):
"""
Load data.
:param path: path to the data file
:return: DataSet
"""
dataset = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
dataset.append(Instance(**ins))
return dataset
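
Note (illustration, not part of the commit): CRLoader is a thin JsonLoader over a jsonlines file with one document per line. The field names below ("doc_key", "sentences", "speakers", "clusters") are the ones the new CoreferencePipe consumes; the record itself and the file path are hypothetical stand-ins for preprocessed OntoNotes data.

import json

# One CoNLL-2012-style document record; all values here are placeholders.
record = {
    "doc_key": "bc/cctv/00/cctv_0000_0",
    "sentences": [["Hello", "world", "."]],
    "speakers": [["Speaker#1", "Speaker#1", "Speaker#1"]],
    "clusters": [],
}

with open("train.english.jsonlines", "w") as f:  # hypothetical path
    f.write(json.dumps(record) + "\n")

# CRLoader().load({"train": "train.english.jsonlines"}) would then yield a DataBundle
# whose "train" DataSet has one instance per document.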

+3 -0  fastNLP/io/pipe/__init__.py

@@ -37,6 +37,8 @@ __all__ = [
"QuoraPipe",
"QNLIPipe",
"MNLIPipe",

"CoreferencePipe"
]

from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
@@ -46,3 +48,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
from .pipe import Pipe
from .conll import Conll2003Pipe
from .cws import CWSPipe
from .coreference import CoreferencePipe

+115 -0  fastNLP/io/pipe/coreference.py

@@ -0,0 +1,115 @@
__all__ = [
"CoreferencePipe"

]

from .pipe import Pipe
from ..data_bundle import DataBundle
from ..loader.coreference import CRLoader
from fastNLP.core.vocabulary import Vocabulary
import numpy as np
import collections


class CoreferencePipe(Pipe):

def __init__(self,config):
super().__init__()
self.config = config

def process(self, data_bundle: DataBundle):
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name='sentences')
vocab.build_vocab()
word2id = vocab.word2idx
char_dict = get_char_dict(self.config.char_path)
for name, ds in data_bundle.datasets.items():
ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[0],
new_field_name='doc_np')
ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[1],
new_field_name='char_index')
ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter),
self.config.max_sentences, is_train=name == 'train')[2],
new_field_name='seq_len')
ds.apply(lambda x: speaker2numpy(x["speakers"], self.config.max_sentences, is_train=name == 'train'),
new_field_name='speaker_ids_np')
ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')

ds.set_ignore_type('clusters')
ds.set_padder('clusters', None)
ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
ds.set_target("clusters")
return data_bundle

def process_from_file(self, paths):
bundle = CRLoader().load(paths)
return self.process(bundle)


# helper

def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train):
docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train)
assert max(length) == max_len
assert char_index.shape[0] == len(length)
assert char_index.shape[1] == max_len
doc_np = np.zeros((len(docvec), max_len), int)
for i in range(len(docvec)):
for j in range(len(docvec[i])):
doc_np[i][j] = docvec[i][j]
return doc_np, char_index, length

def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train):
max_len = 0
max_word_length = 0
docvex = []
length = []
if is_train:
sent_num = min(max_sentences,len(doc))
else:
sent_num = len(doc)

for i in range(sent_num):
sent = doc[i]
length.append(len(sent))
if (len(sent) > max_len):
max_len = len(sent)
sent_vec =[]
for j,word in enumerate(sent):
if len(word)>max_word_length:
max_word_length = len(word)
if word in word2id:
sent_vec.append(word2id[word])
else:
sent_vec.append(word2id["UNK"])
docvex.append(sent_vec)

char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int)
for i in range(sent_num):
sent = doc[i]
for j,word in enumerate(sent):
char_index[i, j, :len(word)] = [char_dict[c] for c in word]

return docvex,char_index,length,max_len

def speaker2numpy(speakers_raw,max_sentences,is_train):
if is_train and len(speakers_raw)> max_sentences:
speakers_raw = speakers_raw[0:max_sentences]
speakers = flatten(speakers_raw)
speaker_dict = {s: i for i, s in enumerate(set(speakers))}
speaker_ids = np.array([speaker_dict[s] for s in speakers])
return speaker_ids

# flatten a list of lists
def flatten(l):
return [item for sublist in l for item in sublist]

def get_char_dict(path):
vocab = ["<UNK>"]
with open(path) as f:
vocab.extend(c.strip() for c in f.readlines())
char_dict = collections.defaultdict(int)
char_dict.update({c: i for i, c in enumerate(vocab)})
return char_dict
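
A rough usage sketch (not from this commit): the pipe only needs a config object exposing the attributes read above (char_path, filter, max_sentences) plus the paths passed to process_from_file. The DummyConfig class, its values, and the file paths below are placeholders, not the repository's actual Config.

from fastNLP.io.pipe.coreference import CoreferencePipe

class DummyConfig:
    char_path = "char_vocab.english.txt"  # hypothetical file, one character per line
    filter = [3, 4, 5]                    # max(filter) is forwarded to doc2numpy as max_filter
    max_sentences = 50                    # training documents are truncated to this many sentences

bundle = CoreferencePipe(DummyConfig()).process_from_file(
    {"train": "train.english.jsonlines",
     "dev": "dev.english.jsonlines",
     "test": "test.english.jsonlines"})

# Each DataSet in the returned bundle carries the fields set in process():
#   input : "sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len"
#   target: "clusters"
print(bundle.datasets["train"][0])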

+1 -1  reproduction/coreference_resolution/README.md

@@ -1,4 +1,4 @@
# Coreference Resolution Reproduction
# Coreference Resolution Reproduction
## Introduction
Coreference resolution is the task of finding all expressions in a text that refer to the same real-world entity.
For many higher-level NLP tasks that involve natural language understanding,


+0 -0  reproduction/coreference_resolution/data_load/__init__.py


+0 -68  reproduction/coreference_resolution/data_load/cr_loader.py

@@ -1,68 +0,0 @@
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
from fastNLP.io.file_reader import _read_json
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from reproduction.coreference_resolution.model.config import Config
import reproduction.coreference_resolution.model.preprocess as preprocess


class CRLoader(JsonLoader):
def __init__(self, fields=None, dropna=False):
super().__init__(fields, dropna)

def _load(self, path):
"""
Load data.
:param path: path to the data file
:return: DataSet
"""
dataset = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
dataset.append(Instance(**ins))
return dataset

def process(self, paths, **kwargs):
data_info = DataBundle()
for name in ['train', 'test', 'dev']:
data_info.datasets[name] = self.load(paths[name])

config = Config()
vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
vocab.build_vocab()
word2id = vocab.word2idx

char_dict = preprocess.get_char_dict(config.char_path)
data_info.vocabs = vocab

genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}

for name, ds in data_info.datasets.items():
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[0],
new_field_name='doc_np')
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[1],
new_field_name='char_index')
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
config.max_sentences, is_train=name=='train')[2],
new_field_name='seq_len')
ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'),
new_field_name='speaker_ids_np')
ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')

ds.set_ignore_type('clusters')
ds.set_padder('clusters', None)
ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
ds.set_target("clusters")

# train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
# train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)

return data_info




+10 -10  reproduction/coreference_resolution/test/test_dataloader.py

@@ -1,14 +1,14 @@


import unittest
from ..data_load.cr_loader import CRLoader
from fastNLP.io.pipe.coreference import CoreferencePipe
from reproduction.coreference_resolution.model.config import Config

class Test_CRLoader(unittest.TestCase):
def test_cr_loader(self):
train_path = 'data/train.english.jsonlines.mini'
dev_path = 'data/dev.english.jsonlines.minid'
test_path = 'data/test.english.jsonlines'
cr = CRLoader()
data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path})

print(data_info.datasets['train'][0])
print(data_info.datasets['dev'][0])
print(data_info.datasets['test'][0])
config = Config()
bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path})

print(bundle.datasets['train'][0])
print(bundle.datasets['dev'][0])
print(bundle.datasets['test'][0])

+4 -6  reproduction/coreference_resolution/train.py

@@ -7,7 +7,8 @@ from torch.optim import Adam
from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer

from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
from fastNLP.io.pipe.coreference import CoreferencePipe

from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
@@ -38,11 +39,8 @@ if __name__ == "__main__":

@cache_results('cache.pkl')
def cache():
cr_train_dev_test = CRLoader()

data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
'test': config.test_path})
return data_info
bundle = CoreferencePipe(Config()).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path})
return bundle
data_info = cache()
print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])),
"\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))


+5 -5  reproduction/coreference_resolution/valid.py

@@ -1,7 +1,8 @@
import torch
from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.metric import CRMetric
from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
from fastNLP.io.pipe.coreference import CoreferencePipe

from fastNLP import Tester
import argparse

@@ -11,13 +12,12 @@ if __name__=='__main__':
parser.add_argument('--path')
args = parser.parse_args()
cr_loader = CRLoader()
config = Config()
data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
'test': config.test_path})
bundle = CoreferencePipe(Config()).process_from_file(
{'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
metirc = CRMetric()
model = torch.load(args.path)
tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
tester = Tester(bundle.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
tester.test()
print('test over')


