@@ -16,7 +16,6 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
-    'MatchingLoader',
     'SNLILoader',
     'SSTLoader',
     'PeopleDailyCorpusLoader',
@@ -27,6 +26,6 @@ __all__ = [
 ]

 from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader, \
+from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \
     SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
 from .model_io import ModelLoader, ModelSaver
@@ -16,7 +16,6 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
-    'MatchingLoader',
     'SNLILoader',
     'SSTLoader',
     'PeopleDailyCorpusLoader',
@@ -250,173 +249,6 @@ class JsonLoader(DataSetLoader):
         return ds


-class MatchingLoader(DataSetLoader):
-    """
-    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
-
-    Reads a Matching data set, preprocesses it according to the data set, and returns a DataInfo.
-
-    Data source:
-        SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-    """
-
-    def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
-        super(MatchingLoader, self).__init__()
-        self.data_format = data_format.lower()
-        self.for_model = for_model.lower()
-        self.bert_dir = bert_dir
-    def _load(self, path: str) -> DataSet:
-        raise NotImplementedError
-
-    def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo:
-        if isinstance(paths, str):
-            paths = {'train': paths}
-
-        data_set = {}
-        for n, p in paths.items():
-            if self.data_format == 'snli':
-                data = self._load_snli(p)
-            else:
-                raise RuntimeError(f'Your data format is {self.data_format}, '
-                                   f'Please choose data format from [snli]')
-
-            if self.for_model == 'esim':
-                data = self._for_esim(data)
-            elif self.for_model == 'bert':
-                data = self._for_bert(data, self.bert_dir)
-            else:
-                raise RuntimeError(f'Your model is {self.for_model}, '
-                                   f'Please choose from [esim, bert]')
-
-            if input_field is not None:
-                if isinstance(input_field, str):
-                    data.set_input(input_field)
-                elif isinstance(input_field, list):
-                    for field in input_field:
-                        data.set_input(field)
-
-            data_set[n] = data
-            print(f'successfully load {n} set!')
-
-        if not hasattr(self, 'vocab'):
-            raise RuntimeError('There is NO vocab attribute built!')
-        if not hasattr(self, 'label_vocab'):
-            raise RuntimeError('There is NO label vocab attribute built!')
-
-        if self.for_model != 'bert':
-            from fastNLP.modules.encoder.embedding import ElmoEmbedding
-            embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2')
-
-        data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
-                             embeddings={'elmo': embedding} if self.for_model != 'bert' else None,
-                             datasets=data_set)
-
-        return data_info
-    @staticmethod
-    def _load_snli(path: str) -> DataSet:
-        """
-        Reads the SNLI data set.
-
-        Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-
-        :param str path: path to the data set
-        :return: the loaded DataSet
-        """
-        raw_ds = JsonLoader(
-            fields={
-                'sentence1_parse': Const.INPUTS(0),
-                'sentence2_parse': Const.INPUTS(1),
-                'gold_label': Const.TARGET,
-            }
-        )._load(path)
-        return raw_ds
-    def _for_esim(self, raw_ds: DataSet):
-        if self.data_format == 'snli' or self.data_format == 'mnli':
-            def parse_tree(x):
-                t = Tree.fromstring(x)
-                return t.leaves()
-
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
-            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
-
-        if not hasattr(self, 'vocab'):
-            self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)])
-        if not hasattr(self, 'label_vocab'):
-            self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)
-
-        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
-        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
-        raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
-        raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0))
-        raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1))
-
-        raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1))
-        raw_ds.set_target(Const.TARGET)
-
-        return raw_ds
-    def _for_bert(self, raw_ds: DataSet, bert_dir: str):
-        if self.data_format == 'snli' or self.data_format == 'mnli':
-            def parse_tree(x):
-                t = Tree.fromstring(x)
-                return t.leaves()
-
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
-            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
-
-        tokenizer = BertTokenizer.from_pretrained(bert_dir)
-
-        vocab = Vocabulary(padding=None, unknown=None)
-        with open(os.path.join(bert_dir, 'vocab.txt')) as f:
-            lines = f.readlines()
-        vocab_list = []
-        for line in lines:
-            vocab_list.append(line.strip())
-        vocab.add_word_lst(vocab_list)
-        vocab.build_vocab()
-        vocab.padding = '[PAD]'
-        vocab.unknown = '[UNK]'
-
-        if not hasattr(self, 'vocab'):
-            self.vocab = vocab
-        else:
-            for w, idx in self.vocab:
-                if vocab[w] != idx:
-                    raise AttributeError(f"{self.__class__.__name__} has ")
-
-        for i in range(2):
-            raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
-        raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
-                     new_field_name=Const.INPUT)
-        raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
-                     new_field_name=Const.INPUT_LENS(0))
-        raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
-
-        max_len = 512
-        raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
-        raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
-        raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
-        raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
-
-        if not hasattr(self, 'label_vocab'):
-            self.label_vocab = Vocabulary(padding=None, unknown=None)
-            self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
-
-        raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
-
-        raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
-        raw_ds.set_target(Const.TARGET)
-
-        return raw_ds


 class SNLILoader(JsonLoader):
     """
     Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
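For reference, the removed loader fixed both the data format and the target model at
construction time; a call looked roughly like this (a sketch based on the deleted code
above; the path is illustrative):

    # old API: preprocessing hard-wired to one model via constructor flags
    loader = MatchingLoader(data_format='snli', for_model='esim')
    data_info = loader.process('/path/to/snli_1.0_train.jsonl')

The replacement below separates reading (JsonLoader fields) from preprocessing options
passed to process().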
@@ -0,0 +1,223 @@
+import os
+
+from nltk import Tree
+from typing import Union, Dict
+
+from fastNLP.core.const import Const
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.dataset import DataSet
+from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.dataset_loader import JsonLoader
+from fastNLP.io.file_utils import _get_base_url, cached_path
+from fastNLP.modules.encoder._bert import BertTokenizer
+
+
+class MatchingLoader(JsonLoader):
+    """
+    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+    Reads data sets for Matching tasks.
+    """
+
+    def __init__(self, fields=None, paths: dict=None):
+        super(MatchingLoader, self).__init__(fields=fields)
+        self.paths = paths
+
+    def _load(self, path):
+        return super(MatchingLoader, self)._load(path)
+
+    def process(self, paths: Union[str, Dict[str, str]], dataset_name=None,
+                to_lower=False, char_information=False, seq_len_type: str=None,
+                bert_tokenizer: str=None, get_index=True, set_input: Union[list, str, bool]=True,
+                set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None) -> DataInfo:
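+        """
+        Reads the data sets given by ``paths`` and preprocesses them into a DataInfo.
+
+        :param paths: a dict mapping data set names to file paths, a directory
+            (combined with the file names in ``self.paths``), or a single file path.
+        :param dataset_name: name for the data set when ``paths`` is a single file.
+        :param to_lower: whether to lower-case the two word fields.
+        :param char_information: not used in this version.
+        :param seq_len_type: ``seq_len`` (lengths), ``mask`` (padding masks) or
+            ``bert`` (segment ids plus mask); controls the length-information fields.
+        :param bert_tokenizer: name of a pretrained BERT vocab or a local model
+            directory; if given, word fields are re-tokenized with it.
+        :param get_index: whether to map tokens to vocabulary indices.
+        :param set_input: field names to set as input, or a bool for automatic defaults.
+        :param set_target: field names to set as target, or a bool for automatic defaults.
+        :param concat: how to concatenate the two sentences into one field:
+            ``bert``, ``default``, a single symbol, or a list of 4 symbols.
+        :return: a DataInfo with ``datasets`` and ``vocabs`` filled in.
+        """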
+        if isinstance(set_input, str):
+            set_input = [set_input]
+        if isinstance(set_target, str):
+            set_target = [set_target]
+        if isinstance(set_input, bool):
+            auto_set_input = set_input
+        else:
+            auto_set_input = False
+        if isinstance(set_target, bool):
+            auto_set_target = set_target
+        else:
+            auto_set_target = False
+        if isinstance(paths, str):
+            if os.path.isdir(paths):
+                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
+            else:
+                path = {dataset_name if dataset_name is not None else 'train': paths}
+        else:
+            path = paths
+
+        data_info = DataInfo()
+        for data_name in path.keys():
+            data_info.datasets[data_name] = self._load(path[data_name])
+
+        for data_name, data_set in data_info.datasets.items():
+            if auto_set_input:
+                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
+            if auto_set_target:
+                data_set.set_target(Const.TARGET)
+
+        if to_lower:
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
+                               is_input=auto_set_input)
+                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
+                               is_input=auto_set_input)
+
+        if bert_tokenizer is not None:
+            PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
+                                         'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
+                                         'en-base-cased': 'bert-base-cased-f89bfe08.zip',
+                                         'en-large-uncased': 'bert-large-uncased-20939f45.zip',
+                                         'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
+                                         'cn': 'bert-base-chinese-29d0a84a.zip',
+                                         'cn-base': 'bert-base-chinese-29d0a84a.zip',
+                                         'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                         'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
+                                         'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                         }
+            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
+                PRETRAIN_URL = _get_base_url('bert')
+                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer.lower()]
+                model_url = PRETRAIN_URL + model_name
+                model_dir = cached_path(model_url)
+                # check whether the model exists in the cache (downloads it if not)
+            elif os.path.isdir(bert_tokenizer):
+                model_dir = bert_tokenizer
+            else:
+                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
+
+            tokenizer = BertTokenizer.from_pretrained(model_dir)
+
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
+                                       is_input=auto_set_input)
+
+        if isinstance(concat, bool):
+            concat = 'default' if concat else None
+        if concat is not None:
+            if isinstance(concat, str):
+                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
+                              'default': ['', '<sep>', '', '']}
+                if concat.lower() in CONCAT_MAP:
+                    concat = CONCAT_MAP[concat.lower()]
+                else:
+                    concat = 4 * [concat]
+            assert len(concat) == 4, \
+                f'Please choose a list with 4 symbols: the beginning of the first sentence, the end of the ' \
+                f'first sentence, the beginning of the second sentence, and the end of the second sentence. ' \
+                f'Your input is {concat}'
+
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
+                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
+                # drop the empty-string placeholders introduced by the concat symbols
+                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
+                               is_input=auto_set_input)
+
+        if seq_len_type is not None:
+            if seq_len_type == 'seq_len':
+                for data_name, data_set in data_info.datasets.items():
+                    for fields in data_set.get_field_names():
+                        if Const.INPUT in fields:
+                            # e.g. words1 -> seq_len1, words2 -> seq_len2
+                            data_set.apply(lambda x: len(x[fields]),
+                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+                                           is_input=auto_set_input)
+            elif seq_len_type == 'mask':
+                for data_name, data_set in data_info.datasets.items():
+                    for fields in data_set.get_field_names():
+                        if Const.INPUT in fields:
+                            data_set.apply(lambda x: [1] * len(x[fields]),
+                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+                                           is_input=auto_set_input)
+            elif seq_len_type == 'bert':
+                for data_name, data_set in data_info.datasets.items():
+                    if Const.INPUT not in data_set.get_field_names():
+                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
+                                       f'got {data_set.get_field_names()}')
+                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
+                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
+                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
+
+        data_set_list = [d for n, d in data_info.datasets.items()]
+        assert len(data_set_list) > 0, 'There are NO data sets in data info!'
+
+        if bert_tokenizer is not None:
+            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+        else:
+            words_vocab = Vocabulary()
+        words_vocab = words_vocab.from_dataset(*data_set_list,
+                                               field_name=[n for n in data_set_list[0].get_field_names()
+                                                           if (Const.INPUT in n)])
+        target_vocab = Vocabulary(padding=None, unknown=None)
+        target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
+        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
+
+        if get_index:
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
+                                       is_input=auto_set_input)
+
+                data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+                               is_input=auto_set_input, is_target=auto_set_target)
+
+        for data_name, data_set in data_info.datasets.items():
+            if isinstance(set_input, list):
+                data_set.set_input(*set_input)
+            if isinstance(set_target, list):
+                data_set.set_target(*set_target)
+
+        return data_info
+
+
+class SNLILoader(MatchingLoader):
+    """
+    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
+
+    Reads the SNLI data set. The loaded DataSet contains the fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self, paths: dict=None):
+        fields = {
+            'sentence1_parse': Const.INPUTS(0),
+            'sentence2_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+        paths = paths if paths is not None else {
+            'train': 'snli_1.0_train.jsonl',
+            'dev': 'snli_1.0_dev.jsonl',
+            'test': 'snli_1.0_test.jsonl'}
+        super(SNLILoader, self).__init__(fields=fields, paths=paths)
+
+    def _load(self, path):
+        ds = super(SNLILoader, self)._load(path)
+
+        def parse_tree(x):
+            # extract the token sequence from the Penn-style parse string
+            t = Tree.fromstring(x)
+            return t.leaves()
+
+        ds.apply(lambda ins: parse_tree(
+            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: parse_tree(
+            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+        # drop instances whose gold label is '-' (no annotator consensus)
+        ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
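A minimal usage sketch of the new API, assuming the three SNLI jsonl files sit in one
local directory (the directory path is illustrative; the keyword arguments are the ones
defined by MatchingLoader.process above):

    loader = SNLILoader()  # resolves a directory against snli_1.0_{train,dev,test}.jsonl
    data_info = loader.process(
        paths='/path/to/snli_1.0',   # illustrative path
        to_lower=True,               # lower-case words1/words2
        seq_len_type='seq_len',      # add seq_len1/seq_len2 length fields
        get_index=True)              # replace tokens with vocabulary indices
    train_set = data_info.datasets['train']
    words_vocab = data_info.vocabs['words']     # keyed by Const.INPUT
    target_vocab = data_info.vocabs['target']   # keyed by Const.TARGET

For a BERT model one would instead pass something like bert_tokenizer='en-base-uncased',
concat='bert' and seq_len_type='bert', yielding a single [CLS] ... [SEP] ... [SEP] words
field plus segment-id and mask fields.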
@@ -1,6 +0,0 @@
-
-from fastNLP.io.dataset_loader import SNLILoader
-
-
-# TODO: still in progress
-
@@ -1,10 +1,10 @@
 import unittest

-from ..data import SNLIDataLoader
+from ..data import MatchingDataLoader
 from fastNLP.core.vocabulary import Vocabulary

 class TestCWSDataLoader(unittest.TestCase):

     def test_case1(self):
-        snli_loader = SNLIDataLoader()
+        snli_loader = MatchingDataLoader()
         # TODO: still in progress
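test_case1 is still a stub; one way it might eventually exercise the loader, assuming
MatchingDataLoader exposes the MatchingLoader API from the new module above and that a
small sample file exists (both the class usage and the file path are hypothetical):

    def test_case1(self):
        loader = MatchingDataLoader(
            fields={'sentence1_parse': 'words1',
                    'sentence2_parse': 'words2',
                    'gold_label': 'target'})
        # hypothetical sample file with a few SNLI-format lines
        ds = loader._load('tests/data_for_tests/sample_snli.jsonl')
        self.assertGreater(len(ds), 0)
        for name in ('words1', 'words2', 'target'):
            self.assertIn(name, ds.get_field_names())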