
update matching dataloader in reproduction/matching

tags/v0.4.10
xuyige 5 years ago
commit d1f531c049
4 changed files with 221 additions and 175 deletions
  1. fastNLP/io/dataset_loader.py (+0, -167)
  2. reproduction/matching/data/MatchingDataLoader.py (+219, -0)
  3. reproduction/matching/data/SNLIDataLoader.py (+0, -6)
  4. reproduction/matching/test/test_snlidataloader.py (+2, -2)

+ 0
- 167
fastNLP/io/dataset_loader.py

@@ -250,173 +250,6 @@ class JsonLoader(DataSetLoader):
return ds


class MatchingLoader(DataSetLoader):
"""
Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

Reads a matching dataset, preprocesses it according to the data format, and returns a DataInfo.

Data source:
SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""

def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
super(MatchingLoader, self).__init__()
self.data_format = data_format.lower()
self.for_model = for_model.lower()
self.bert_dir = bert_dir

def _load(self, path: str) -> DataSet:
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo:
if isinstance(paths, str):
paths = {'train': paths}

data_set = {}
for n, p in paths.items():
if self.data_format == 'snli':
data = self._load_snli(p)
else:
raise RuntimeError(f'Unsupported data format: {self.data_format}. '
f'Please choose a data format from [snli].')

if self.for_model == 'esim':
data = self._for_esim(data)
elif self.for_model == 'bert':
data = self._for_bert(data, self.bert_dir)
else:
raise RuntimeError(f'Unsupported model: {self.for_model}. '
f'Please choose a model from [esim, bert].')

if input_field is not None:
if isinstance(input_field, str):
data.set_input(input_field)
elif isinstance(input_field, list):
for field in input_field:
data.set_input(field)

data_set[n] = data
print(f'successfully loaded the {n} set!')

if not hasattr(self, 'vocab'):
raise RuntimeError('No vocab attribute has been built!')
if not hasattr(self, 'label_vocab'):
raise RuntimeError('No label_vocab attribute has been built!')

if self.for_model != 'bert':
from fastNLP.modules.encoder.embedding import ElmoEmbedding
embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2')

data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
embeddings={'elmo': embedding} if self.for_model != 'bert' else None,
datasets=data_set)

return data_info

@staticmethod
def _load_snli(path: str) -> DataSet:
"""
Reads the SNLI dataset.

Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
:param str path: path to the dataset file
:return: a DataSet containing the raw SNLI fields
"""
raw_ds = JsonLoader(
fields={
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
)._load(path)
return raw_ds

def _for_esim(self, raw_ds: DataSet):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

if not hasattr(self, 'vocab'):
self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)])
if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)

raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1))

raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds

def _for_bert(self, raw_ds: DataSet, bert_dir: str):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

tokenizer = BertTokenizer.from_pretrained(bert_dir)

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

if not hasattr(self, 'vocab'):
self.vocab = vocab
else:
for w, idx in self.vocab:
if vocab[w] != idx:
raise AttributeError(f"{self.__class__.__name__} already has a vocab that conflicts with the BERT vocab (word '{w}' maps to a different index).")

for i in range(2):
raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None)
self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds


class SNLILoader(JsonLoader):
"""
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`


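For comparison with the new loader added below, here is a minimal sketch of how the removed MatchingLoader was driven. It is based only on the deleted code above; the SNLI file paths and the local BERT directory are placeholders.

# A sketch based on the deleted MatchingLoader above; file paths are placeholders.
from fastNLP.io.dataset_loader import MatchingLoader

# ESIM preprocessing: parse the SNLI trees, build word/label vocabularies,
# and attach an ELMo embedding to the returned DataInfo.
esim_loader = MatchingLoader(data_format='snli', for_model='esim')
esim_info = esim_loader.process({
    'train': 'path/to/snli_1.0_train.jsonl',
    'dev': 'path/to/snli_1.0_dev.jsonl',
})

# BERT preprocessing instead needs a local BERT directory containing vocab.txt.
bert_loader = MatchingLoader(data_format='snli', for_model='bert',
                             bert_dir='path/to/bert-base-uncased')
bert_info = bert_loader.process('path/to/snli_1.0_train.jsonl')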
+ 219
- 0
reproduction/matching/data/MatchingDataLoader.py

@@ -0,0 +1,219 @@

import os

from nltk import Tree
from typing import Union, Dict

from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.io.base_loader import DataInfo
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.io.file_utils import _get_base_url, cached_path
from fastNLP.modules.encoder._bert import BertTokenizer


class MatchingLoader(JsonLoader):
"""
Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

Reads datasets for matching tasks.
"""

def __init__(self, fields=None, paths: dict=None):
super(MatchingLoader, self).__init__(fields=fields)
self.paths = paths

def _load(self, path):
return super(MatchingLoader, self)._load(path)

def process(self, paths: Union[str, Dict[str, str]], dataset_name=None,
to_lower=False, char_information=False, seq_len_type: str=None,
bert_tokenizer: str=None, get_index=True, set_input: Union[list, bool]=True,
set_target: Union[list, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
if isinstance(set_input, bool):
auto_set_input = set_input
else:
auto_set_input = False
if isinstance(set_target, bool):
auto_set_target = set_target
else:
auto_set_target = False
if isinstance(paths, str):
if os.path.isdir(paths):
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
else:
path = {dataset_name if dataset_name is not None else 'train': paths}
else:
path = paths

data_info = DataInfo()
for data_name in path.keys():
data_info.datasets[data_name] = self._load(path[data_name])

for data_name, data_set in data_info.datasets.items():
if auto_set_input:
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
if auto_set_target:
data_set.set_target(Const.TARGET)

if to_lower:
for data_name, data_set in data_info.datasets.items():
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
is_input=auto_set_input)
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
is_input=auto_set_input)

if bert_tokenizer is not None:
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
}
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
# check whether the path exists
elif os.path.isdir(bert_tokenizer):
model_dir = bert_tokenizer
else:
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")

tokenizer = BertTokenizer.from_pretrained(model_dir)

for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
is_input=auto_set_input)

if isinstance(concat, bool):
concat = 'default' if concat else None
if concat is not None:
if isinstance(concat, str):
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
'default': ['', '<sep>', '', '']}
if concat.lower() in CONCAT_MAP:
concat = CONCAT_MAP[concat]
else:
concat = 4 * [concat]
assert len(concat) == 4, \
f'Please provide a list of 4 symbols: the start of the first sentence, ' \
f'the end of the first sentence, the start of the second sentence, ' \
f'and the end of the second sentence. Your input is {concat}'

for data_name, data_set in data_info.datasets.items():
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
is_input=auto_set_input)

if seq_len_type is not None:
if seq_len_type == 'seq_len':
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: len(x[fields]),
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
is_input=auto_set_input)
elif seq_len_type == 'mask':
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: [1] * len(x[fields]),
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
is_input=auto_set_input)
elif seq_len_type == 'bert':
for data_name, data_set in data_info.datasets.items():
if Const.INPUT not in data_set.get_field_names():
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
f'got {data_set.get_field_names()}')
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)

data_set_list = [d for n, d in data_info.datasets.items()]
assert len(data_set_list) > 0, 'There are no data sets in the data info!'

if bert_tokenizer is not None:
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
else:
words_vocab = Vocabulary()
words_vocab = words_vocab.from_dataset(*data_set_list,
field_name=[n for n in data_set_list[0].get_field_names()
if (Const.INPUT in n)])
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}

if get_index:
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
is_input=auto_set_input)

data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
is_input=auto_set_input, is_target=auto_set_target)

for data_name, data_set in data_info.datasets.items():
if isinstance(set_input, list):
data_set.set_input(*set_input)
if isinstance(set_target, list):
data_set.set_target(*set_target)

return data_info


class SNLILoader(MatchingLoader):
"""
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

Reads the SNLI dataset. The resulting DataSet contains the fields::

words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label

Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""

def __init__(self, paths: dict=None):
fields = {
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
paths = paths if paths is not None else {
'train': 'snli_1.0_train.jsonl',
'dev': 'snli_1.0_dev.jsonl',
'test': 'snli_1.0_test.jsonl'}
super(SNLILoader, self).__init__(fields=fields, paths=paths)

def _load(self, path):
ds = super(SNLILoader, self)._load(path)

def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds




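A minimal usage sketch of the new loader (not part of this commit); the SNLI data directory and the local BERT directory are placeholders, and the keyword arguments mirror the process() signature shown above.

from reproduction.matching.data.MatchingDataLoader import SNLILoader

loader = SNLILoader()

# ESIM-style preprocessing: lower-cased word fields plus per-sentence seq_len fields.
esim_info = loader.process(
    'path/to/snli_1.0',          # a directory, so the default file names in self.paths are used
    to_lower=True,
    seq_len_type='seq_len',
    get_index=True,
    set_input=True,
    set_target=True,
)

# BERT-style preprocessing: WordPiece tokenization, [CLS]/[SEP] concatenation
# into a single words field, plus BERT-style segment and mask fields.
bert_info = loader.process(
    'path/to/snli_1.0',
    bert_tokenizer='path/to/bert-base-uncased',   # a local directory containing vocab.txt
    seq_len_type='bert',
    concat='bert',
)

print(esim_info.datasets['train'][0])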
+ 0
- 6
reproduction/matching/data/SNLIDataLoader.py

@@ -1,6 +0,0 @@

from fastNLP.io.dataset_loader import SNLILoader

# TODO: still in progress



+ 2
- 2
reproduction/matching/test/test_snlidataloader.py

@@ -1,10 +1,10 @@
import unittest
from ..data import SNLIDataLoader
from ..data import MatchingDataLoader
from fastNLP.core.vocabulary import Vocabulary


class TestCWSDataLoader(unittest.TestCase):
def test_case1(self):
snli_loader = SNLIDataLoader()
snli_loader = MatchingDataLoader()
# TODO: still in progress

