diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index b0bf2e60..01e6c8ed 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -250,173 +250,6 @@ class JsonLoader(DataSetLoader):
         return ds
 
 
-class MatchingLoader(DataSetLoader):
-    """
-    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
-
-    Reads a Matching dataset, preprocesses it according to the dataset and returns a DataInfo.
-
-    Data source:
-    SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-    """
-
-    def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
-        super(MatchingLoader, self).__init__()
-        self.data_format = data_format.lower()
-        self.for_model = for_model.lower()
-        self.bert_dir = bert_dir
-
-    def _load(self, path: str) -> DataSet:
-        raise NotImplementedError
-
-    def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo:
-        if isinstance(paths, str):
-            paths = {'train': paths}
-
-        data_set = {}
-        for n, p in paths.items():
-            if self.data_format == 'snli':
-                data = self._load_snli(p)
-            else:
-                raise RuntimeError(f'Your data format is {self.data_format}, '
-                                   f'Please choose data format from [snli]')
-
-            if self.for_model == 'esim':
-                data = self._for_esim(data)
-            elif self.for_model == 'bert':
-                data = self._for_bert(data, self.bert_dir)
-            else:
-                raise RuntimeError(f'Your model is {self.data_format}, '
-                                   f'Please choose from [esim, bert]')
-
-            if input_field is not None:
-                if isinstance(input_field, str):
-                    data.set_input(input_field)
-                elif isinstance(input_field, list):
-                    for field in input_field:
-                        data.set_input(field)
-
-            data_set[n] = data
-            print(f'successfully load {n} set!')
-
-        if not hasattr(self, 'vocab'):
-            raise RuntimeError(f'There is NOT vocab attribute built!')
-        if not hasattr(self, 'label_vocab'):
-            raise RuntimeError(f'There is NOT label vocab attribute built!')
-
-        if self.for_model != 'bert':
-            from fastNLP.modules.encoder.embedding import ElmoEmbedding
-            embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2')
-
-        data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
-                             embeddings={'elmo': embedding} if self.for_model != 'bert' else None,
-                             datasets=data_set)
-
-        return data_info
-
-    @staticmethod
-    def _load_snli(path: str) -> DataSet:
-        """
-        Reads the SNLI dataset.
-
-        Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-        :param str path: path to the dataset
-        :return:
-        """
-        raw_ds = JsonLoader(
-            fields={
-                'sentence1_parse': Const.INPUTS(0),
-                'sentence2_parse': Const.INPUTS(1),
-                'gold_label': Const.TARGET,
-            }
-        )._load(path)
-        return raw_ds
-
-    def _for_esim(self, raw_ds: DataSet):
-        if self.data_format == 'snli' or self.data_format == 'mnli':
-            def parse_tree(x):
-                t = Tree.fromstring(x)
-                return t.leaves()
-
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
-            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
-
-        if not hasattr(self, 'vocab'):
-            self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)])
-        if not hasattr(self, 'label_vocab'):
-            self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)
-
-        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
-        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
-        raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
-        raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0))
-        raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1))
-
-        raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1))
-        raw_ds.set_target(Const.TARGET)
-
-        return raw_ds
-
-    def _for_bert(self, raw_ds: DataSet, bert_dir: str):
-        if self.data_format == 'snli' or self.data_format == 'mnli':
-            def parse_tree(x):
-                t = Tree.fromstring(x)
-                return t.leaves()
-
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
-            raw_ds.apply(lambda ins: parse_tree(
-                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
-            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
-
-        tokenizer = BertTokenizer.from_pretrained(bert_dir)
-
-        vocab = Vocabulary(padding=None, unknown=None)
-        with open(os.path.join(bert_dir, 'vocab.txt')) as f:
-            lines = f.readlines()
-        vocab_list = []
-        for line in lines:
-            vocab_list.append(line.strip())
-        vocab.add_word_lst(vocab_list)
-        vocab.build_vocab()
-        vocab.padding = '[PAD]'
-        vocab.unknown = '[UNK]'
-
-        if not hasattr(self, 'vocab'):
-            self.vocab = vocab
-        else:
-            for w, idx in self.vocab:
-                if vocab[w] != idx:
-                    raise AttributeError(f"{self.__class__.__name__} has ")
-
-        for i in range(2):
-            raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
-        raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
-                     new_field_name=Const.INPUT)
-        raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
-                     new_field_name=Const.INPUT_LENS(0))
-        raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
-
-        max_len = 512
-        raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
-        raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
-        raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
-        raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
-
-        if not hasattr(self, 'label_vocab'):
-            self.label_vocab = Vocabulary(padding=None, unknown=None)
-            self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
-        raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
-
-        raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
-        raw_ds.set_target(Const.TARGET)
-
-        return raw_ds
-
-
 class SNLILoader(JsonLoader):
     """
     Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
new file mode 100644
index 00000000..305143b9
--- /dev/null
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -0,0 +1,219 @@
+
+import os
+
+from nltk import Tree
+from typing import Union, Dict
+
+from fastNLP.core.const import Const
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.dataset import DataSet
+from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.dataset_loader import JsonLoader
+from fastNLP.io.file_utils import _get_base_url, cached_path
+from fastNLP.modules.encoder._bert import BertTokenizer
+
+
+class MatchingLoader(JsonLoader):
+    """
+    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+    Reads datasets for Matching tasks.
+    """
+
+    def __init__(self, fields=None, paths: dict=None):
+        super(MatchingLoader, self).__init__(fields=fields)
+        self.paths = paths
+
+    def _load(self, path):
+        return super(MatchingLoader, self)._load(path)
+
+    def process(self, paths: Union[str, Dict[str, str]], dataset_name=None,
+                to_lower=False, char_information=False, seq_len_type: str=None,
+                bert_tokenizer: str=None, get_index=True, set_input: Union[list, bool]=True,
+                set_target: Union[list, bool]=True, concat: Union[str, list, bool]=None) -> DataInfo:
+        if isinstance(set_input, bool):
+            auto_set_input = set_input
+        else:
+            auto_set_input = False
+        if isinstance(set_target, bool):
+            auto_set_target = set_target
+        else:
+            auto_set_target = False
+        if isinstance(paths, str):
+            if os.path.isdir(paths):
+                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
+            else:
+                path = {dataset_name if dataset_name is not None else 'train': paths}
+        else:
+            path = paths
+
+        data_info = DataInfo()
+        for data_name in path.keys():
+            data_info.datasets[data_name] = self._load(path[data_name])
+
+        for data_name, data_set in data_info.datasets.items():
+            if auto_set_input:
+                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
+            if auto_set_target:
+                data_set.set_target(Const.TARGET)
+
+        if to_lower:
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
+                               is_input=auto_set_input)
+                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
+                               is_input=auto_set_input)
+
+        if bert_tokenizer is not None:
+            PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
+                                         'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
+                                         'en-base-cased': 'bert-base-cased-f89bfe08.zip',
+                                         'en-large-uncased': 'bert-large-uncased-20939f45.zip',
+                                         'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
+
+                                         'cn': 'bert-base-chinese-29d0a84a.zip',
+                                         'cn-base': 'bert-base-chinese-29d0a84a.zip',
+
+                                         'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                         'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
+                                         'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                         }
+            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
+                PRETRAIN_URL = _get_base_url('bert')
+                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
+                model_url = PRETRAIN_URL + model_name
+                model_dir = cached_path(model_url)
+                # check whether the cached model directory exists
+            elif os.path.isdir(bert_tokenizer):
+                model_dir = bert_tokenizer
+            else:
+                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
+
+            tokenizer = BertTokenizer.from_pretrained(model_dir)
+
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
+                                       is_input=auto_set_input)
+
+        if isinstance(concat, bool):
+            concat = 'default' if concat else None
+        if concat is not None:
+            if isinstance(concat, str):
+                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
+                              'default': ['', '', '', '']}
+                if concat.lower() in CONCAT_MAP:
+                    concat = CONCAT_MAP[concat]
+                else:
+                    concat = 4 * [concat]
+            assert len(concat) == 4, \
+                f'Please choose a list with 4 symbols which are the start of the first sentence, ' \
+                f'the end of the first sentence, the start of the second sentence, and the end of ' \
+                f'the second sentence. Your input is {concat}'
+
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
+                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
+                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
+                               is_input=auto_set_input)
+
+        if seq_len_type is not None:
+            if seq_len_type == 'seq_len':
+                for data_name, data_set in data_info.datasets.items():
+                    for fields in data_set.get_field_names():
+                        if Const.INPUT in fields:
+                            data_set.apply(lambda x: len(x[fields]),
+                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+                                           is_input=auto_set_input)
+            elif seq_len_type == 'mask':
+                for data_name, data_set in data_info.datasets.items():
+                    for fields in data_set.get_field_names():
+                        if Const.INPUT in fields:
+                            data_set.apply(lambda x: [1] * len(x[fields]),
+                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+                                           is_input=auto_set_input)
+            elif seq_len_type == 'bert':
+                for data_name, data_set in data_info.datasets.items():
+                    if Const.INPUT not in data_set.get_field_names():
+                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
+                                       f'got {data_set.get_field_names()}')
+                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
+                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
+                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
+
+        data_set_list = [d for n, d in data_info.datasets.items()]
+        assert len(data_set_list) > 0, f'There are NO data sets in data info!'
+
+        if bert_tokenizer is not None:
+            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+        else:
+            words_vocab = Vocabulary()
+        words_vocab = words_vocab.from_dataset(*data_set_list,
+                                               field_name=[n for n in data_set_list[0].get_field_names()
+                                                           if (Const.INPUT in n)])
+        target_vocab = Vocabulary(padding=None, unknown=None)
+        target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
+        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
+
+        if get_index:
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
+                                       is_input=auto_set_input)
+
+                data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+                               is_input=auto_set_input, is_target=auto_set_target)
+
+        for data_name, data_set in data_info.datasets.items():
+            if isinstance(set_input, list):
+                data_set.set_input(*set_input)
+            if isinstance(set_target, list):
+                data_set.set_target(*set_target)
+
+        return data_info
+
+
+class SNLILoader(MatchingLoader):
+    """
+    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
+
+    Reads the SNLI dataset. The loaded DataSet contains the following fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self, paths: dict=None):
+        fields = {
+            'sentence1_parse': Const.INPUTS(0),
+            'sentence2_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+        paths = paths if paths is not None else {
+            'train': 'snli_1.0_train.jsonl',
+            'dev': 'snli_1.0_dev.jsonl',
+            'test': 'snli_1.0_test.jsonl'}
+        super(SNLILoader, self).__init__(fields=fields, paths=paths)
+
+    def _load(self, path):
+        ds = super(SNLILoader, self)._load(path)
+
+        def parse_tree(x):
+            t = Tree.fromstring(x)
+            return t.leaves()
+
+        ds.apply(lambda ins: parse_tree(
+            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: parse_tree(
+            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+        ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
+
+
+
diff --git a/reproduction/matching/data/SNLIDataLoader.py b/reproduction/matching/data/SNLIDataLoader.py
deleted file mode 100644
index 6f6bbecd..00000000
--- a/reproduction/matching/data/SNLIDataLoader.py
+++ /dev/null
@@ -1,6 +0,0 @@
-
-from fastNLP.io.dataset_loader import SNLILoader
-
-# TODO: still in progress
-
-
diff --git a/reproduction/matching/test/test_snlidataloader.py b/reproduction/matching/test/test_snlidataloader.py
index bd5c58b6..60b3ad59 100644
--- a/reproduction/matching/test/test_snlidataloader.py
+++ b/reproduction/matching/test/test_snlidataloader.py
@@ -1,10 +1,10 @@
 import unittest
-from ..data import SNLIDataLoader
+from ..data import MatchingDataLoader
 from fastNLP.core.vocabulary import Vocabulary
 
 
 class TestCWSDataLoader(unittest.TestCase):
 
     def test_case1(self):
-        snli_loader = SNLIDataLoader()
+        snli_loader = MatchingDataLoader()
         # TODO: still in progress
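
Usage sketch (not part of the patch): a minimal way to drive the new loader, following the process() signature added in MatchingDataLoader.py above. The local ./snli_1.0 directory and the specific argument choices are assumptions for illustration, not something this diff prescribes.

    from fastNLP.core.const import Const
    from reproduction.matching.data.MatchingDataLoader import SNLILoader

    # Point the loader at an unpacked snli_1.0 release (assumed local path);
    # the default paths dict maps train/dev/test to the jsonl files inside it.
    loader = SNLILoader()
    data_info = loader.process(
        paths='./snli_1.0',      # assumed location of the unzipped SNLI data
        to_lower=True,           # lower-case both word fields
        seq_len_type='seq_len',  # add one sequence-length field per input field
        get_index=True,          # map words and labels to vocabulary indices
        concat=False,            # keep words1/words2 separate (ESIM-style input)
    )
    print(data_info.datasets['train'][0])
    print(len(data_info.vocabs[Const.INPUT]))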