@@ -0,0 +1,219 @@
import os
from typing import Union, Dict

from nltk import Tree

from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.io.base_loader import DataInfo
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.io.file_utils import _get_base_url, cached_path
from fastNLP.modules.encoder._bert import BertTokenizer


class MatchingLoader(JsonLoader):
    """
    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

    Loader for matching (sentence-pair) task data sets.
    """

    def __init__(self, fields=None, paths: dict=None):
        super(MatchingLoader, self).__init__(fields=fields)
        self.paths = paths

    def _load(self, path):
        return super(MatchingLoader, self)._load(path)

    def process(self, paths: Union[str, Dict[str, str]], dataset_name=None,
                to_lower=False, char_information=False, seq_len_type: str=None,
                bert_tokenizer: str=None, get_index=True, set_input: Union[list, bool]=True,
                set_target: Union[list, bool]=True, concat: Union[str, list, bool]=None) -> DataInfo:
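        """
        Read the splits given by ``paths``, preprocess them, and return a ``DataInfo`` holding the
        resulting data sets and vocabularies (the parameter summaries below are inferred from this
        method's behavior).

        :param paths: a single file, a directory containing the files named in ``self.paths``, or a
            dict mapping split names to file paths.
        :param dataset_name: split name used when ``paths`` is a single file; defaults to ``'train'``.
        :param to_lower: lowercase all tokens of both sentences.
        :param char_information: accepted for interface compatibility; not used by this method.
        :param seq_len_type: ``'seq_len'``, ``'mask'`` or ``'bert'``; controls which length/mask
            fields are added.
        :param bert_tokenizer: name of a pretrained BERT model or a local model directory; if given,
            the input fields are re-tokenized with ``BertTokenizer``.
        :param get_index: map tokens and labels to vocabulary indices.
        :param set_input: bool to toggle the automatic ``is_input`` flags, or an explicit field list.
        :param set_target: bool to toggle the automatic ``is_target`` flags, or an explicit field list.
        :param concat: ``False``/``None``, ``'bert'``, ``'default'``, a single separator string, or a
            list of four boundary symbols used to join the two sentences into one input field.
        """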
        if isinstance(set_input, bool):
            auto_set_input = set_input
        else:
            auto_set_input = False
        if isinstance(set_target, bool):
            auto_set_target = set_target
        else:
            auto_set_target = False
        if isinstance(paths, str):
            if os.path.isdir(paths):
                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
            else:
                path = {dataset_name if dataset_name is not None else 'train': paths}
        else:
            path = paths

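        # load every split listed in ``path`` into its own DataSet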
        data_info = DataInfo()
        for data_name in path.keys():
            data_info.datasets[data_name] = self._load(path[data_name])

        for data_name, data_set in data_info.datasets.items():
            if auto_set_input:
                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
            if auto_set_target:
                data_set.set_target(Const.TARGET)

        if to_lower:
            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
                               is_input=auto_set_input)
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
                               is_input=auto_set_input)

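        # optionally re-tokenize every input field with a pretrained BERT tokenizer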
        if bert_tokenizer is not None:
            PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
                                         'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
                                         'en-base-cased': 'bert-base-cased-f89bfe08.zip',
                                         'en-large-uncased': 'bert-large-uncased-20939f45.zip',
                                         'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

                                         'cn': 'bert-base-chinese-29d0a84a.zip',
                                         'cn-base': 'bert-base-chinese-29d0a84a.zip',

                                         'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
                                         'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
                                         'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
                                         }
            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
                PRETRAIN_URL = _get_base_url('bert')
                # look up the model by its lower-cased name, matching the membership check above
                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer.lower()]
                model_url = PRETRAIN_URL + model_name
                # download the archive if it is not already cached
                model_dir = cached_path(model_url)
            elif os.path.isdir(bert_tokenizer):
                model_dir = bert_tokenizer
            else:
                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")

            tokenizer = BertTokenizer.from_pretrained(model_dir)

            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
                                       is_input=auto_set_input)

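        # optionally concatenate the two sentences into a single input field (``Const.INPUT``),
        # wrapped in the four boundary symbols given by ``concat``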
        if isinstance(concat, bool):
            concat = 'default' if concat else None
        if concat is not None:
            if isinstance(concat, str):
                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
                              'default': ['', '<sep>', '', '']}
                if concat.lower() in CONCAT_MAP:
                    concat = CONCAT_MAP[concat.lower()]
                else:
                    concat = 4 * [concat]
            assert len(concat) == 4, \
                'Please provide a list of 4 symbols: the beginning of the first sentence, ' \
                'the end of the first sentence, the beginning of the second sentence, and ' \
                f'the end of the second sentence. Your input is {concat}'

            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
                               is_input=auto_set_input)

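        # add sequence-length / mask / BERT segment-id fields, depending on ``seq_len_type``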
        if seq_len_type is not None:
            if seq_len_type == 'seq_len':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: len(x[fields]),
                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'mask':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: [1] * len(x[fields]),
                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'bert':
                for data_name, data_set in data_info.datasets.items():
                    if Const.INPUT not in data_set.get_field_names():
                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
                                       f'got {data_set.get_field_names()}')
                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)

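        # build the word vocabulary over all input fields and the target vocabulary over the labels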
        data_set_list = [d for n, d in data_info.datasets.items()]
        assert len(data_set_list) > 0, 'There are NO data sets in data info!'

        if bert_tokenizer is not None:
            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
        else:
            words_vocab = Vocabulary()
        words_vocab = words_vocab.from_dataset(*data_set_list,
                                               field_name=[n for n in data_set_list[0].get_field_names()
                                                           if (Const.INPUT in n)])
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}

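        # map tokens and labels to vocabulary indices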
        if get_index:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
                                       is_input=auto_set_input)

                data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
                               is_input=auto_set_input, is_target=auto_set_target)

        for data_name, data_set in data_info.datasets.items():
            if isinstance(set_input, list):
                data_set.set_input(*set_input)
            if isinstance(set_target, list):
                data_set.set_target(*set_target)

        return data_info


class SNLILoader(MatchingLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Loads the SNLI data set. The resulting DataSet contains the following fields::

        words1: list(str), tokens of the first sentence (the premise)
        words2: list(str), tokens of the second sentence (the hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict=None):
        fields = {
            'sentence1_parse': Const.INPUTS(0),
            'sentence2_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        paths = paths if paths is not None else {
            'train': 'snli_1.0_train.jsonl',
            'dev': 'snli_1.0_dev.jsonl',
            'test': 'snli_1.0_test.jsonl'}
        super(SNLILoader, self).__init__(fields=fields, paths=paths)

    def _load(self, path):
        ds = super(SNLILoader, self)._load(path)

        def parse_tree(x):
            # SNLI stores each sentence as a parse tree; keep only the leaf tokens
            t = Tree.fromstring(x)
            return t.leaves()

        ds.apply(lambda ins: parse_tree(
            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: parse_tree(
            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
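        # instances whose annotators reached no consensus carry the gold label '-'; drop them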
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
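
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as a comment so importing this module
# has no side effects). It assumes the SNLI jsonl files live under a local
# directory ``./snli_1.0/`` with their default names; that path is an
# assumption, while every keyword shown is a parameter of ``process`` above.
#
#     loader = SNLILoader()
#     data_info = loader.process(paths='./snli_1.0/', to_lower=True,
#                                seq_len_type='seq_len', bert_tokenizer=None,
#                                get_index=True, concat=False,
#                                set_input=True, set_target=True)
#     train_set = data_info.datasets['train']
#     words_vocab = data_info.vocabs[Const.INPUT]
#     target_vocab = data_info.vocabs[Const.TARGET]
# ---------------------------------------------------------------------------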