From ea0f2f7e00188ab44bad21d8a6e53aa55601a3b6 Mon Sep 17 00:00:00 2001 From: xuyige Date: Mon, 19 Aug 2019 20:48:08 +0800 Subject: [PATCH] update reproduction/matching to adapt version 0.5.0: 1) move loader codes from DataLoader to PiPe; 2) fix some bugs in matching pipe; 3) delete some expire codes. --- fastNLP/io/pipe/matching.py | 12 +- .../matching/data/MatchingDataLoader.py | 435 ------------------ reproduction/matching/matching_bert.py | 76 +-- reproduction/matching/matching_cntn.py | 42 +- reproduction/matching/matching_esim.py | 69 ++- reproduction/matching/matching_mwan.py | 60 +-- reproduction/matching/model/bert.py | 35 +- reproduction/matching/model/cntn.py | 20 +- reproduction/matching/model/esim.py | 21 +- .../matching/test/test_snlidataloader.py | 10 - 10 files changed, 142 insertions(+), 638 deletions(-) delete mode 100644 reproduction/matching/data/MatchingDataLoader.py delete mode 100644 reproduction/matching/test/test_snlidataloader.py diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 2eaeef58..0d1b4e82 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -89,13 +89,15 @@ class MatchingBertPipe(Pipe): data_bundle.set_vocab(word_vocab, Const.INPUT) data_bundle.set_vocab(target_vocab, Const.TARGET) - input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET] + input_fields = [Const.INPUT, Const.INPUT_LEN] target_fields = [Const.TARGET] for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) dataset.set_input(*input_fields, flag=True) - dataset.set_target(*target_fields, flag=True) + for fields in target_fields: + if dataset.has_field(fields): + dataset.set_target(fields, flag=True) return data_bundle @@ -210,14 +212,16 @@ class MatchingPipe(Pipe): data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) data_bundle.set_vocab(target_vocab, Const.TARGET) - input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET] + input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)] target_fields = [Const.TARGET] for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) dataset.set_input(*input_fields, flag=True) - dataset.set_target(*target_fields, flag=True) + for fields in target_fields: + if dataset.has_field(fields): + dataset.set_target(fields, flag=True) return data_bundle diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py deleted file mode 100644 index f13618aa..00000000 --- a/reproduction/matching/data/MatchingDataLoader.py +++ /dev/null @@ -1,435 +0,0 @@ -""" -这个文件的内容已合并到fastNLP.io.data_loader里,这个文件的内容不再更新 -""" - - -import os - -from typing import Union, Dict - -from fastNLP.core.const import Const -from fastNLP.core.vocabulary import Vocabulary -from fastNLP.io.data_bundle import DataBundle, DataSetLoader -from fastNLP.io.dataset_loader import JsonLoader, CSVLoader -from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR -from fastNLP.modules.encoder._bert import BertTokenizer - - -class MatchingLoader(DataSetLoader): - """ - 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` - - 读取Matching任务的数据集 - - :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 - """ - - def __init__(self, paths: dict=None): - self.paths = paths - - def _load(self, path): - """ - :param str path: 待读取数据集的路径名 - :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 - 的原始字符串文本,第三个为标签 - """ - raise NotImplementedError - - def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, - to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, - cut_text: int = None, get_index=True, auto_pad_length: int=None, - auto_pad_token: str='', set_input: Union[list, str, bool]=True, - set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle: - """ - :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, - 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 - 对应的全路径文件名。 - :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 - 这个数据集的名字,如果不定义则默认为train。 - :param bool to_lower: 是否将文本自动转为小写。默认值为False。 - :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : - 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 - attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len - :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 - :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 - :param bool get_index: 是否需要根据词表将文本转为index - :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad - :param str auto_pad_token: 自动pad的内容 - :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False - 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, - 于此同时其他field不会被设置为input。默认值为True。 - :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 - :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。 - 如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 - 传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. - :return: - """ - if isinstance(set_input, str): - set_input = [set_input] - if isinstance(set_target, str): - set_target = [set_target] - if isinstance(set_input, bool): - auto_set_input = set_input - else: - auto_set_input = False - if isinstance(set_target, bool): - auto_set_target = set_target - else: - auto_set_target = False - if isinstance(paths, str): - if os.path.isdir(paths): - path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} - else: - path = {dataset_name if dataset_name is not None else 'train': paths} - else: - path = paths - - data_info = DataBundle() - for data_name in path.keys(): - data_info.datasets[data_name] = self._load(path[data_name]) - - for data_name, data_set in data_info.datasets.items(): - if auto_set_input: - data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) - if auto_set_target: - if Const.TARGET in data_set.get_field_names(): - data_set.set_target(Const.TARGET) - - if to_lower: - for data_name, data_set in data_info.datasets.items(): - data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), - is_input=auto_set_input) - data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), - is_input=auto_set_input) - - if bert_tokenizer is not None: - if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: - PRETRAIN_URL = _get_base_url('bert') - model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) - # 检查是否存在 - elif os.path.isdir(bert_tokenizer): - model_dir = bert_tokenizer - else: - raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") - - words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') - with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: - lines = f.readlines() - lines = [line.strip() for line in lines] - words_vocab.add_word_lst(lines) - words_vocab.build_vocab() - - tokenizer = BertTokenizer.from_pretrained(model_dir) - - for data_name, data_set in data_info.datasets.items(): - for fields in data_set.get_field_names(): - if Const.INPUT in fields: - data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, - is_input=auto_set_input) - - if isinstance(concat, bool): - concat = 'default' if concat else None - if concat is not None: - if isinstance(concat, str): - CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], - 'default': ['', '', '', '']} - if concat.lower() in CONCAT_MAP: - concat = CONCAT_MAP[concat] - else: - concat = 4 * [concat] - assert len(concat) == 4, \ - f'Please choose a list with 4 symbols which at the beginning of first sentence ' \ - f'the end of first sentence, the begin of second sentence, and the end of second' \ - f'sentence. Your input is {concat}' - - for data_name, data_set in data_info.datasets.items(): - data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + - x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) - data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, - is_input=auto_set_input) - - if seq_len_type is not None: - if seq_len_type == 'seq_len': # - for data_name, data_set in data_info.datasets.items(): - for fields in data_set.get_field_names(): - if Const.INPUT in fields: - data_set.apply(lambda x: len(x[fields]), - new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), - is_input=auto_set_input) - elif seq_len_type == 'mask': - for data_name, data_set in data_info.datasets.items(): - for fields in data_set.get_field_names(): - if Const.INPUT in fields: - data_set.apply(lambda x: [1] * len(x[fields]), - new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), - is_input=auto_set_input) - elif seq_len_type == 'bert': - for data_name, data_set in data_info.datasets.items(): - if Const.INPUT not in data_set.get_field_names(): - raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' - f'got {data_set.get_field_names()}') - data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), - new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) - data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), - new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) - - if auto_pad_length is not None: - cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) - - if cut_text is not None: - for data_name, data_set in data_info.datasets.items(): - for fields in data_set.get_field_names(): - if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): - data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, - is_input=auto_set_input) - - data_set_list = [d for n, d in data_info.datasets.items()] - assert len(data_set_list) > 0, f'There are NO data sets in data info!' - - if bert_tokenizer is None: - words_vocab = Vocabulary(padding=auto_pad_token) - words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], - field_name=[n for n in data_set_list[0].get_field_names() - if (Const.INPUT in n)], - no_create_entry_dataset=[d for n, d in data_info.datasets.items() - if 'train' not in n]) - target_vocab = Vocabulary(padding=None, unknown=None) - target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], - field_name=Const.TARGET) - data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} - - if get_index: - for data_name, data_set in data_info.datasets.items(): - for fields in data_set.get_field_names(): - if Const.INPUT in fields: - data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, - is_input=auto_set_input) - - if Const.TARGET in data_set.get_field_names(): - data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, - is_input=auto_set_input, is_target=auto_set_target) - - if auto_pad_length is not None: - if seq_len_type == 'seq_len': - raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' - f'so the seq_len_type cannot be `{seq_len_type}`!') - for data_name, data_set in data_info.datasets.items(): - for fields in data_set.get_field_names(): - if Const.INPUT in fields: - data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * - (auto_pad_length - len(x[fields])), new_field_name=fields, - is_input=auto_set_input) - elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): - data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), - new_field_name=fields, is_input=auto_set_input) - - for data_name, data_set in data_info.datasets.items(): - if isinstance(set_input, list): - data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) - if isinstance(set_target, list): - data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) - - return data_info - - -class SNLILoader(MatchingLoader, JsonLoader): - """ - 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` - - 读取SNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip - """ - - def __init__(self, paths: dict=None): - fields = { - 'sentence1_binary_parse': Const.INPUTS(0), - 'sentence2_binary_parse': Const.INPUTS(1), - 'gold_label': Const.TARGET, - } - paths = paths if paths is not None else { - 'train': 'snli_1.0_train.jsonl', - 'dev': 'snli_1.0_dev.jsonl', - 'test': 'snli_1.0_test.jsonl'} - MatchingLoader.__init__(self, paths=paths) - JsonLoader.__init__(self, fields=fields) - - def _load(self, path): - ds = JsonLoader._load(self, path) - - parentheses_table = str.maketrans({'(': None, ')': None}) - - ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(0)) - ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(1)) - ds.drop(lambda x: x[Const.TARGET] == '-') - return ds - - -class RTELoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` - - 读取RTE数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev': 'dev.tsv', - 'test': 'test.tsv' # test set has not label - } - MatchingLoader.__init__(self, paths=paths) - self.fields = { - 'sentence1': Const.INPUTS(0), - 'sentence2': Const.INPUTS(1), - 'label': Const.TARGET, - } - CSVLoader.__init__(self, sep='\t') - - def _load(self, path): - ds = CSVLoader._load(self, path) - - for k, v in self.fields.items(): - if v in ds.get_field_names(): - ds.rename_field(k, v) - for fields in ds.get_all_fields(): - if Const.INPUT in fields: - ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) - - return ds - - -class QNLILoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` - - 读取QNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev': 'dev.tsv', - 'test': 'test.tsv' # test set has not label - } - MatchingLoader.__init__(self, paths=paths) - self.fields = { - 'question': Const.INPUTS(0), - 'sentence': Const.INPUTS(1), - 'label': Const.TARGET, - } - CSVLoader.__init__(self, sep='\t') - - def _load(self, path): - ds = CSVLoader._load(self, path) - - for k, v in self.fields.items(): - if v in ds.get_field_names(): - ds.rename_field(k, v) - for fields in ds.get_all_fields(): - if Const.INPUT in fields: - ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) - - return ds - - -class MNLILoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` - - 读取MNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev_matched': 'dev_matched.tsv', - 'dev_mismatched': 'dev_mismatched.tsv', - 'test_matched': 'test_matched.tsv', - 'test_mismatched': 'test_mismatched.tsv', - # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', - # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', - - # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) - } - MatchingLoader.__init__(self, paths=paths) - CSVLoader.__init__(self, sep='\t') - self.fields = { - 'sentence1_binary_parse': Const.INPUTS(0), - 'sentence2_binary_parse': Const.INPUTS(1), - 'gold_label': Const.TARGET, - } - - def _load(self, path): - ds = CSVLoader._load(self, path) - - for k, v in self.fields.items(): - if k in ds.get_field_names(): - ds.rename_field(k, v) - - if Const.TARGET in ds.get_field_names(): - if ds[0][Const.TARGET] == 'hidden': - ds.delete_field(Const.TARGET) - - parentheses_table = str.maketrans({'(': None, ')': None}) - - ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(0)) - ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(1)) - if Const.TARGET in ds.get_field_names(): - ds.drop(lambda x: x[Const.TARGET] == '-') - return ds - - -class QuoraLoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` - - 读取MNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev': 'dev.tsv', - 'test': 'test.tsv', - } - MatchingLoader.__init__(self, paths=paths) - CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) - - def _load(self, path): - ds = CSVLoader._load(self, path) - return ds diff --git a/reproduction/matching/matching_bert.py b/reproduction/matching/matching_bert.py index 3ed75fd1..323d81a3 100644 --- a/reproduction/matching/matching_bert.py +++ b/reproduction/matching/matching_bert.py @@ -2,8 +2,12 @@ import random import numpy as np import torch -from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam -from fastNLP.io.data_loader import SNLILoader, RTELoader, MNLILoader, QNLILoader, QuoraLoader +from fastNLP.core import Trainer, Tester, AccuracyMetric, Const +from fastNLP.core.callback import WarmupCallback, EvaluateCallback +from fastNLP.core.optimizer import AdamW +from fastNLP.embeddings import BertEmbedding +from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\ + QNLIBertPipe, QuoraBertPipe from reproduction.matching.model.bert import BertForNLI @@ -12,16 +16,22 @@ from reproduction.matching.model.bert import BertForNLI class BERTConfig: task = 'snli' + batch_size_per_gpu = 6 n_epochs = 6 lr = 2e-5 - seq_len_type = 'bert' + warm_up_rate = 0.1 seed = 42 + save_path = None # 模型存储的位置,None表示不存储模型。 + train_dataset_name = 'train' dev_dataset_name = 'dev' test_dataset_name = 'test' - save_path = None # 模型存储的位置,None表示不存储模型。 - bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹 + + to_lower = True # 忽略大小写 + tokenizer = 'spacy' # 使用spacy进行分词 + + bert_model_dir_or_name = 'bert-base-uncased' arg = BERTConfig() @@ -37,58 +47,52 @@ if n_gpu > 0: # load data set if arg.task == 'snli': - data_info = SNLILoader().process( - paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, - bert_tokenizer=arg.bert_dir, cut_text=512, - get_index=True, concat='bert', - ) + data_bundle = SNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'rte': - data_info = RTELoader().process( - paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, - bert_tokenizer=arg.bert_dir, cut_text=512, - get_index=True, concat='bert', - ) + data_bundle = RTEBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'qnli': - data_info = QNLILoader().process( - paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, - bert_tokenizer=arg.bert_dir, cut_text=512, - get_index=True, concat='bert', - ) + data_bundle = QNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'mnli': - data_info = MNLILoader().process( - paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, - bert_tokenizer=arg.bert_dir, cut_text=512, - get_index=True, concat='bert', - ) + data_bundle = MNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'quora': - data_info = QuoraLoader().process( - paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type, - bert_tokenizer=arg.bert_dir, cut_text=512, - get_index=True, concat='bert', - ) + data_bundle = QuoraBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() else: raise RuntimeError(f'NOT support {arg.task} task yet!') +print(data_bundle) # print details in data_bundle + +# load embedding +embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name) + # define model -model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir) +model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET])) + +# define optimizer and callback +optimizer = AdamW(lr=arg.lr, params=model.parameters()) +callbacks = [WarmupCallback(warmup=arg.warm_up_rate, schedule='linear'), ] + +if arg.task in ['snli']: + callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name])) + # evaluate test set in every epoch if task is snli. # define trainer -trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, - optimizer=Adam(lr=arg.lr, model_params=model.parameters()), +trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model, + optimizer=optimizer, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, n_epochs=arg.n_epochs, print_every=-1, - dev_data=data_info.datasets[arg.dev_dataset_name], + dev_data=data_bundle.datasets[arg.dev_dataset_name], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1, - save_path=arg.save_path) + save_path=arg.save_path, + callbacks=callbacks) # train model trainer.train(load_best_model=True) # define tester tester = Tester( - data=data_info.datasets[arg.test_dataset_name], + data=data_bundle.datasets[arg.test_dataset_name], model=model, metrics=AccuracyMetric(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, diff --git a/reproduction/matching/matching_cntn.py b/reproduction/matching/matching_cntn.py index 098f3bc4..9be716ba 100644 --- a/reproduction/matching/matching_cntn.py +++ b/reproduction/matching/matching_cntn.py @@ -1,9 +1,9 @@ import argparse import torch -from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss from fastNLP.embeddings import StaticEmbedding -from fastNLP.io.data_loader import QNLILoader, RTELoader, SNLILoader, MNLILoader +from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe from reproduction.matching.model.cntn import CNTNModel @@ -13,14 +13,12 @@ argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glo argument.add_argument('--batch-size-per-gpu', type=int, default=256) argument.add_argument('--n-epochs', type=int, default=200) argument.add_argument('--lr', type=float, default=1e-5) -argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask') argument.add_argument('--save-dir', type=str, default=None) argument.add_argument('--cntn-depth', type=int, default=1) argument.add_argument('--cntn-ns', type=int, default=200) argument.add_argument('--cntn-k-top', type=int, default=10) argument.add_argument('--cntn-r', type=int, default=5) argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli') -argument.add_argument('--max-len', type=int, default=50) arg = argument.parse_args() # dataset dict @@ -45,30 +43,25 @@ else: num_labels = 3 # load data set -if arg.dataset == 'qnli': - data_info = QNLILoader().process( - paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, - get_index=True, concat=False, auto_pad_length=arg.max_len) +if arg.dataset == 'snli': + data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file() elif arg.dataset == 'rte': - data_info = RTELoader().process( - paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, - get_index=True, concat=False, auto_pad_length=arg.max_len) -elif arg.dataset == 'snli': - data_info = SNLILoader().process( - paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, - get_index=True, concat=False, auto_pad_length=arg.max_len) + data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file() +elif arg.dataset == 'qnli': + data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file() elif arg.dataset == 'mnli': - data_info = MNLILoader().process( - paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, - get_index=True, concat=False, auto_pad_length=arg.max_len) + data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file() else: - raise ValueError(f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!') + raise RuntimeError(f'NOT support {arg.task} task yet!') + +print(data_bundle) # print details in data_bundle # load embedding if arg.embedding == 'word2vec': - embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300', requires_grad=True) + embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300', + requires_grad=True) elif arg.embedding == 'glove': - embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300', + embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d', requires_grad=True) else: raise ValueError(f'now we only support word2vec or glove embedding for cntn model!') @@ -79,11 +72,12 @@ model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=nu print(model) # define trainer -trainer = Trainer(train_data=data_info.datasets['train'], model=model, +trainer = Trainer(train_data=data_bundle.datasets['train'], model=model, optimizer=Adam(lr=arg.lr, model_params=model.parameters()), + loss=CrossEntropyLoss(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, n_epochs=arg.n_epochs, print_every=-1, - dev_data=data_info.datasets[dev_dict[arg.dataset]], + dev_data=data_bundle.datasets[dev_dict[arg.dataset]], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1) @@ -93,7 +87,7 @@ trainer.train(load_best_model=True) # define tester tester = Tester( - data=data_info.datasets[test_dict[arg.dataset]], + data=data_bundle.datasets[test_dict[arg.dataset]], model=model, metrics=AccuracyMetric(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, diff --git a/reproduction/matching/matching_esim.py b/reproduction/matching/matching_esim.py index 2ff6916a..9d50c0fb 100644 --- a/reproduction/matching/matching_esim.py +++ b/reproduction/matching/matching_esim.py @@ -6,10 +6,11 @@ from torch.optim import Adamax from torch.optim.lr_scheduler import StepLR from fastNLP.core import Trainer, Tester, AccuracyMetric, Const -from fastNLP.core.callback import GradientClipCallback, LRScheduler -from fastNLP.embeddings.static_embedding import StaticEmbedding -from fastNLP.embeddings.elmo_embedding import ElmoEmbedding -from fastNLP.io.data_loader import SNLILoader, RTELoader, MNLILoader, QNLILoader, QuoraLoader +from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback +from fastNLP.core.losses import CrossEntropyLoss +from fastNLP.embeddings import StaticEmbedding +from fastNLP.embeddings import ElmoEmbedding +from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe from fastNLP.models.snli import ESIM @@ -17,18 +18,21 @@ from fastNLP.models.snli import ESIM class ESIMConfig: task = 'snli' + embedding = 'glove' + batch_size_per_gpu = 196 n_epochs = 30 lr = 2e-3 - seq_len_type = 'seq_len' - # seq_len表示在process的时候用len(words)来表示长度信息; - # mask表示用0/1掩码矩阵来表示长度信息; seed = 42 + save_path = None # 模型存储的位置,None表示不存储模型。 + train_dataset_name = 'train' dev_dataset_name = 'dev' test_dataset_name = 'test' - save_path = None # 模型存储的位置,None表示不存储模型。 + + to_lower = True # 忽略大小写 + tokenizer = 'spacy' # 使用spacy进行分词 arg = ESIMConfig() @@ -44,43 +48,32 @@ if n_gpu > 0: # load data set if arg.task == 'snli': - data_info = SNLILoader().process( - paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, - get_index=True, concat=False, - ) + data_bundle = SNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'rte': - data_info = RTELoader().process( - paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, - get_index=True, concat=False, - ) + data_bundle = RTEPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'qnli': - data_info = QNLILoader().process( - paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, - get_index=True, concat=False, - ) + data_bundle = QNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'mnli': - data_info = MNLILoader().process( - paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, - get_index=True, concat=False, - ) + data_bundle = MNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() elif arg.task == 'quora': - data_info = QuoraLoader().process( - paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, - get_index=True, concat=False, - ) + data_bundle = QuoraPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() else: raise RuntimeError(f'NOT support {arg.task} task yet!') +print(data_bundle) # print details in data_bundle + # load embedding if arg.embedding == 'elmo': - embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) + embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium', + requires_grad=True) elif arg.embedding == 'glove': - embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) + embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d', + requires_grad=True, normalize=False) else: raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') # define model -model = ESIM(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) +model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET])) # define optimizer and callback optimizer = Adamax(lr=arg.lr, params=model.parameters()) @@ -91,23 +84,29 @@ callbacks = [ LRScheduler(scheduler), ] +if arg.task in ['snli']: + callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name])) + # evaluate test set in every epoch if task is snli. + # define trainer -trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, +trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model, optimizer=optimizer, + loss=CrossEntropyLoss(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, n_epochs=arg.n_epochs, print_every=-1, - dev_data=data_info.datasets[arg.dev_dataset_name], + dev_data=data_bundle.datasets[arg.dev_dataset_name], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1, - save_path=arg.save_path) + save_path=arg.save_path, + callbacks=callbacks) # train model trainer.train(load_best_model=True) # define tester tester = Tester( - data=data_info.datasets[arg.test_dataset_name], + data=data_bundle.datasets[arg.test_dataset_name], model=model, metrics=AccuracyMetric(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, diff --git a/reproduction/matching/matching_mwan.py b/reproduction/matching/matching_mwan.py index 31af54c5..026ea7b4 100644 --- a/reproduction/matching/matching_mwan.py +++ b/reproduction/matching/matching_mwan.py @@ -6,12 +6,11 @@ from torch.optim import Adadelta from torch.optim.lr_scheduler import StepLR from fastNLP import CrossEntropyLoss -from fastNLP import cache_results from fastNLP.core import Trainer, Tester, AccuracyMetric, Const -from fastNLP.core.callback import LRScheduler, FitlogCallback +from fastNLP.core.callback import LRScheduler, EvaluateCallback from fastNLP.embeddings import StaticEmbedding -from fastNLP.io.data_loader import MNLILoader, QNLILoader, SNLILoader, RTELoader +from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe from reproduction.matching.model.mwan import MwanModel import fitlog @@ -46,47 +45,25 @@ for k in arg.__dict__: # load data set if arg.task == 'snli': - @cache_results(f'snli_mwan.pkl') - def read_snli(): - data_info = SNLILoader().process( - paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, - get_index=True, concat=False, extra_split=['/','%','-'], - ) - return data_info - data_info = read_snli() + data_bundle = SNLIPipe(lower=True, tokenizer='spacy').process_from_file() elif arg.task == 'rte': - @cache_results(f'rte_mwan.pkl') - def read_rte(): - data_info = RTELoader().process( - paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, - get_index=True, concat=False, extra_split=['/','%','-'], - ) - return data_info - data_info = read_rte() + data_bundle = RTEPipe(lower=True, tokenizer='spacy').process_from_file() elif arg.task == 'qnli': - data_info = QNLILoader().process( - paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, - get_index=True, concat=False , cut_text=512, extra_split=['/','%','-'], - ) + data_bundle = QNLIPipe(lower=True, tokenizer='spacy').process_from_file() elif arg.task == 'mnli': - @cache_results(f'mnli_v0.9_mwan.pkl') - def read_mnli(): - data_info = MNLILoader().process( - paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, - get_index=True, concat=False, extra_split=['/','%','-'], - ) - return data_info - data_info = read_mnli() + data_bundle = MNLIPipe(lower=True, tokenizer='spacy').process_from_file() +elif arg.task == 'quora': + data_bundle = QuoraPipe(lower=True, tokenizer='spacy').process_from_file() else: raise RuntimeError(f'NOT support {arg.task} task yet!') -print(data_info) -print(len(data_info.vocabs['words'])) +print(data_bundle) +print(len(data_bundle.vocabs[Const.INPUTS(0)])) model = MwanModel( - num_class = len(data_info.vocabs[Const.TARGET]), - EmbLayer = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False), + num_class = len(data_bundle.vocabs[Const.TARGET]), + EmbLayer = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], requires_grad=False, normalize=False), ElmoLayer = None, args_of_imm = { "input_size" : 300 , @@ -105,21 +82,20 @@ callbacks = [ ] if arg.task in ['snli']: - callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1)) + callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.testset_name])) elif arg.task == 'mnli': - callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'], - 'dev_mismatched': data_info.datasets['dev_mismatched']}, - verbose=1)) + callbacks.append(EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'], + 'dev_mismatched': data_bundle.datasets['dev_mismatched']},)) trainer = Trainer( - train_data = data_info.datasets['train'], + train_data = data_bundle.datasets['train'], model = model, optimizer = optimizer, num_workers = 0, batch_size = arg.batch_size, n_epochs = arg.n_epochs, print_every = -1, - dev_data = data_info.datasets[arg.devset_name], + dev_data = data_bundle.datasets[arg.devset_name], metrics = AccuracyMetric(pred = "pred" , target = "target"), metric_key = 'acc', device = [i for i in range(torch.cuda.device_count())], @@ -130,7 +106,7 @@ trainer = Trainer( trainer.train(load_best_model=True) tester = Tester( - data=data_info.datasets[arg.testset_name], + data=data_bundle.datasets[arg.testset_name], model=model, metrics=AccuracyMetric(), batch_size=arg.batch_size, diff --git a/reproduction/matching/model/bert.py b/reproduction/matching/model/bert.py index a21f8c36..73a0c533 100644 --- a/reproduction/matching/model/bert.py +++ b/reproduction/matching/model/bert.py @@ -3,39 +3,28 @@ import torch import torch.nn as nn from fastNLP.core.const import Const -from fastNLP.models import BaseModel -from fastNLP.embeddings.bert import BertModel +from fastNLP.models.base_model import BaseModel +from fastNLP.embeddings import BertEmbedding class BertForNLI(BaseModel): - # TODO: still in progress - def __init__(self, class_num=3, bert_dir=None): + def __init__(self, bert_embed: BertEmbedding, class_num=3): super(BertForNLI, self).__init__() - if bert_dir is not None: - self.bert = BertModel.from_pretrained(bert_dir) - else: - self.bert = BertModel() - hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1) - self.classifier = nn.Linear(hidden_size, class_num) - - def forward(self, words, seq_len1, seq_len2, target=None): + self.embed = bert_embed + self.classifier = nn.Linear(self.embed.embedding_dim, class_num) + + def forward(self, words): """ :param torch.Tensor words: [batch_size, seq_len] input_ids - :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids - :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask - :param torch.Tensor target: [batch] :return: """ - _, pooled_output = self.bert(words, seq_len1, seq_len2) - logits = self.classifier(pooled_output) + hidden = self.embed(words) + logits = self.classifier(hidden) - if target is not None: - loss_func = torch.nn.CrossEntropyLoss() - loss = loss_func(logits, target) - return {Const.OUTPUT: logits, Const.LOSS: loss} return {Const.OUTPUT: logits} - def predict(self, words, seq_len1, seq_len2, target=None): - return self.forward(words, seq_len1, seq_len2) + def predict(self, words): + logits = self.forward(words)[Const.OUTPUT] + return {Const.OUTPUT: logits.argmax(dim=-1)} diff --git a/reproduction/matching/model/cntn.py b/reproduction/matching/model/cntn.py index a0a104a3..cfa5e5a8 100644 --- a/reproduction/matching/model/cntn.py +++ b/reproduction/matching/model/cntn.py @@ -3,10 +3,8 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np -from torch.nn import CrossEntropyLoss - -from fastNLP.models import BaseModel -from fastNLP.embeddings.embedding import TokenEmbedding +from fastNLP.models.base_model import BaseModel +from fastNLP.embeddings import TokenEmbedding from fastNLP.core.const import Const @@ -83,13 +81,12 @@ class CNTNModel(BaseModel): self.weight_V = nn.Linear(2 * ns, r) self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels)) - def forward(self, words1, words2, seq_len1, seq_len2, target=None): + def forward(self, words1, words2, seq_len1, seq_len2): """ :param words1: [batch, seq_len, emb_size] Question. :param words2: [batch, seq_len, emb_size] Answer. :param seq_len1: [batch] :param seq_len2: [batch] - :param target: [batch] Glod labels. :return: """ in_q = self.embedding(words1) @@ -109,12 +106,7 @@ class CNTNModel(BaseModel): in_a = self.fc_q(in_a.view(in_a.size(0), -1)) score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1)))) - if target is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(score, target) - return {Const.LOSS: loss, Const.OUTPUT: score} - else: - return {Const.OUTPUT: score} + return {Const.OUTPUT: score} - def predict(self, **kwargs): - return self.forward(**kwargs) + def predict(self, words1, words2, seq_len1, seq_len2): + return self.forward(words1, words2, seq_len1, seq_len2) diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py index 87e5ba65..d704e2f8 100644 --- a/reproduction/matching/model/esim.py +++ b/reproduction/matching/model/esim.py @@ -2,10 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn import CrossEntropyLoss - -from fastNLP.models import BaseModel -from fastNLP.embeddings.embedding import TokenEmbedding +from fastNLP.models.base_model import BaseModel +from fastNLP.embeddings import TokenEmbedding from fastNLP.core.const import Const from fastNLP.core.utils import seq_len_to_mask @@ -42,13 +40,12 @@ class ESIMModel(BaseModel): nn.init.xavier_uniform_(self.classifier[1].weight.data) nn.init.xavier_uniform_(self.classifier[4].weight.data) - def forward(self, words1, words2, seq_len1, seq_len2, target=None): + def forward(self, words1, words2, seq_len1, seq_len2): """ :param words1: [batch, seq_len] :param words2: [batch, seq_len] :param seq_len1: [batch] :param seq_len2: [batch] - :param target: :return: """ mask1 = seq_len_to_mask(seq_len1, words1.size(1)) @@ -82,16 +79,10 @@ class ESIMModel(BaseModel): logits = torch.tanh(self.classifier(out)) # logits = self.classifier(out) - if target is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits, target) - - return {Const.LOSS: loss, Const.OUTPUT: logits} - else: - return {Const.OUTPUT: logits} + return {Const.OUTPUT: logits} - def predict(self, **kwargs): - pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) + def predict(self, words1, words2, seq_len1, seq_len2): + pred = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT].argmax(-1) return {Const.OUTPUT: pred} # input [batch_size, len , hidden] diff --git a/reproduction/matching/test/test_snlidataloader.py b/reproduction/matching/test/test_snlidataloader.py deleted file mode 100644 index 60b3ad59..00000000 --- a/reproduction/matching/test/test_snlidataloader.py +++ /dev/null @@ -1,10 +0,0 @@ -import unittest -from ..data import MatchingDataLoader -from fastNLP.core.vocabulary import Vocabulary - - -class TestCWSDataLoader(unittest.TestCase): - def test_case1(self): - snli_loader = MatchingDataLoader() - # TODO: still in progress -