From 43d3380b730398ac4594edfbfc28b9e8fc55ce77 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 24 Jun 2019 18:31:38 +0800 Subject: [PATCH 01/10] =?UTF-8?q?1.=E4=BF=AE=E5=A4=8DTrainer=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E7=9A=84=E5=A4=9Adevice=20bug;=202.=E5=9C=A8?= =?UTF-8?q?CrossEntropyLoss=E4=B8=AD=E5=A2=9E=E5=8A=A0seq=5Flen?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 20 ++++++++++++-------- fastNLP/core/trainer.py | 14 ++++++-------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 62e7a8c8..526bf37a 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -26,7 +26,7 @@ from .utils import _build_args from .utils import _check_arg_dict_list from .utils import _check_function_or_method from .utils import _get_func_signature - +from .utils import seq_len_to_mask class LossBase(object): """ @@ -223,7 +223,9 @@ class CrossEntropyLoss(LossBase): :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` - :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容 + :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 + :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 + 传入seq_len. Example:: @@ -231,16 +233,18 @@ class CrossEntropyLoss(LossBase): """ - def __init__(self, pred=None, target=None, padding_idx=-100): + def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): super(CrossEntropyLoss, self).__init__() - self._init_param_map(pred=pred, target=target) + self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.padding_idx = padding_idx - def get_loss(self, pred, target): + def get_loss(self, pred, target, seq_len=None): if pred.dim()>2: - if pred.size()[:2]==target.size(): - # F.cross_entropy在计算时,如果pred是(16, 10 ,4), 会在第二维上去log_softmax, 所以需要交换一下位置 - pred = pred.transpose(1, 2) + pred = pred.view(-1, pred.size(-1)) + target = target.view(-1) + if seq_len is not None: + mask = seq_len_to_mask(seq_len).view(-1).eq(0) + target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a303f742..e8dfa814 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -452,17 +452,15 @@ class Trainer(object): else: raise TypeError("train_data type {} not support".format(type(train_data))) - self.model = _move_model_to_device(model, device=device) - if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): - _check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, + _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, metric_key=metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 - + self.model = _move_model_to_device(model, device=device) + self.train_data = train_data self.dev_data = dev_data # If None, No validation. 
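A minimal standalone sketch of what the seq_len handling added to CrossEntropyLoss above amounts to (plain torch only; the helper name masked_cross_entropy and the explicit mask construction are illustrative, the patch itself relies on fastNLP's seq_len_to_mask):

    import torch
    import torch.nn.functional as F

    def masked_cross_entropy(pred, target, seq_len, padding_idx=-100):
        # pred: (batch, max_len, n_classes), target: (batch, max_len), seq_len: (batch,)
        max_len = target.size(1)
        # True inside each sentence, False on the padded tail
        in_sent = torch.arange(max_len, device=target.device)[None, :] < seq_len[:, None]
        pred = pred.view(-1, pred.size(-1))
        # tokens past seq_len are rewritten to padding_idx, which ignore_index then drops
        target = target.view(-1).masked_fill(~in_sent.view(-1), padding_idx)
        return F.cross_entropy(pred, target, ignore_index=padding_idx)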
- self.model = model self.losser = losser self.metrics = metrics self.n_epochs = int(n_epochs) @@ -480,16 +478,16 @@ class Trainer(object): if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer elif isinstance(optimizer, Optimizer): - self.optimizer = optimizer.construct_from_pytorch(model.parameters()) + self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) elif optimizer is None: - self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) else: raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) self.use_tqdm = use_tqdm self.pbar = None self.print_every = abs(self.print_every) - + if self.dev_data is not None: self.tester = Tester(model=self.model, data=self.dev_data, From e0b23b16db59b249bc4ffbcbe45f4d8f99b7bbd8 Mon Sep 17 00:00:00 2001 From: xuyige Date: Mon, 24 Jun 2019 21:44:43 +0800 Subject: [PATCH 02/10] update data loader of matching --- fastNLP/io/file_utils.py | 31 +++++++ fastNLP/modules/encoder/embedding.py | 41 ++------- .../matching/data/MatchingDataLoader.py | 92 ++++++++++++------- reproduction/matching/matching_esim.py | 65 +++++++++++++ reproduction/matching/model/esim.py | 21 ++++- 5 files changed, 178 insertions(+), 72 deletions(-) create mode 100644 reproduction/matching/matching_esim.py diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index d178626b..04970cb3 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -10,6 +10,37 @@ import shutil import hashlib +PRETRAINED_BERT_MODEL_DIR = { + 'en': 'bert-base-cased-f89bfe08.zip', + 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', + 'en-base-cased': 'bert-base-cased-f89bfe08.zip', + 'en-large-uncased': 'bert-large-uncased-20939f45.zip', + 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', + + 'cn': 'bert-base-chinese-29d0a84a.zip', + 'cn-base': 'bert-base-chinese-29d0a84a.zip', + + 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', + 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', + 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', +} + +PRETRAINED_ELMO_MODEL_DIR = { + 'en': 'elmo_en-d39843fe.tar.gz', + 'cn': 'elmo_cn-5e9b34e2.tar.gz' +} + +PRETRAIN_STATIC_FILES = { + 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', + 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', + 'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", + 'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", + 'en-fasttext': "cc.en.300.vec-d53187b2.gz", + 'cn': "tencent_cn-dab24577.tar.gz", + 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", +} + + def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: """ 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index c6c95bb7..a58668da 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -26,6 +26,7 @@ from ...core.dataset import DataSet from ...core.batch import DataSetIter from ...core.sampler import SequentialSampler from ...core.utils import _move_model_to_device, _get_model_device +from ...io.file_utils import PRETRAINED_BERT_MODEL_DIR, PRETRAINED_ELMO_MODEL_DIR, PRETRAIN_STATIC_FILES class Embedding(nn.Module): @@ -187,15 +188,6 @@ class StaticEmbedding(TokenEmbedding): super(StaticEmbedding, self).__init__(vocab) # 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, - 
PRETRAIN_STATIC_FILES = { - 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', - 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', - 'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", - 'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", - 'en-fasttext': "cc.en.300.vec-d53187b2.gz", - 'cn': "tencent_cn-dab24577.tar.gz", - 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", - } # 得到cache_path if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: @@ -231,7 +223,7 @@ class StaticEmbedding(TokenEmbedding): :return: """ requires_grads = set([param.requires_grad for name, param in self.named_parameters() - if 'words_to_words' not in name]) + if 'words_to_words' not in name]) if len(requires_grads) == 1: return requires_grads.pop() else: @@ -244,8 +236,8 @@ class StaticEmbedding(TokenEmbedding): continue param.requires_grad = value - def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='', unknown='', normalize=True, - error='ignore', init_method=None): + def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='', unknown='', + normalize=True, error='ignore', init_method=None): """ 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 word2vec(第一行只有两个元素)还是glove格式的数据。 @@ -329,11 +321,6 @@ class ContextualEmbedding(TokenEmbedding): """ 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 - Example:: - - >>> - - :param datasets: DataSet对象 :param batch_size: int, 生成cache的sentence表示时使用的batch的大小 :param device: 参考 :class::fastNLP.Trainer 的device @@ -363,7 +350,7 @@ class ContextualEmbedding(TokenEmbedding): seq_len = words.ne(pad_index).sum(dim=-1) max_len = words.size(1) # 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。 - seq_len_from_behind =(max_len - seq_len).tolist() + seq_len_from_behind = (max_len - seq_len).tolist() word_embeds = self(words).detach().cpu().numpy() for b in range(words.size(0)): length = seq_len_from_behind[b] @@ -446,9 +433,6 @@ class ElmoEmbedding(ContextualEmbedding): self.layers = layers # 根据model_dir_or_name检查是否存在并下载 - PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', - 'cn': 'elmo_cn-5e9b34e2.tar.gz'} - if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: PRETRAIN_URL = _get_base_url('elmo') model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] @@ -532,21 +516,8 @@ class BertEmbedding(ContextualEmbedding): def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): super(BertEmbedding, self).__init__(vocab) - # 根据model_dir_or_name检查是否存在并下载 - PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', - 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', - 'en-base-cased': 'bert-base-cased-f89bfe08.zip', - 'en-large-uncased': 'bert-large-uncased-20939f45.zip', - 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', - - 'cn': 'bert-base-chinese-29d0a84a.zip', - 'cn-base': 'bert-base-chinese-29d0a84a.zip', - - 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', - 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', - 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', - } + # 根据model_dir_or_name检查是否存在并下载 if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: PRETRAIN_URL = _get_base_url('bert') model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 139b1d4f..4868598a 100644 --- 
a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -6,31 +6,58 @@ from typing import Union, Dict from fastNLP.core.const import Const from fastNLP.core.vocabulary import Vocabulary -from fastNLP.core.dataset import DataSet from fastNLP.io.base_loader import DataInfo -from fastNLP.io.dataset_loader import JsonLoader -from fastNLP.io.file_utils import _get_base_url, cached_path +from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader +from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from fastNLP.modules.encoder._bert import BertTokenizer -class MatchingLoader(JsonLoader): +class MatchingLoader(DataSetLoader): """ 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` 读取Matching任务的数据集 """ - def __init__(self, fields=None, paths: dict=None): - super(MatchingLoader, self).__init__(fields=fields) + def __init__(self, paths: dict=None): + """ + :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 + """ self.paths = paths def _load(self, path): - return super(MatchingLoader, self)._load(path) - - def process(self, paths: Union[str, Dict[str, str]], dataset_name=None, - to_lower=False, char_information=False, seq_len_type: str=None, - bert_tokenizer: str=None, get_index=True, set_input: Union[list, str, bool]=True, + """ + :param str path: 待读取数据集的路径名 + :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 + 的原始字符串文本,第三个为标签 + """ + raise NotImplementedError + + def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, + to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, + get_index=True, set_input: Union[list, str, bool]=True, set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: + """ + :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, + 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 + 对应的全路径文件名。 + :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 + 这个数据集的名字,如果不定义则默认为train。 + :param bool to_lower: 是否将文本自动转为小写。默认值为False。 + :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : + 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 + attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len + :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 + :param bool get_index: 是否需要根据词表将文本转为index + :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False + 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, + 于此同时其他field不会被设置为input。默认值为True。 + :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 + :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。 + 如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 + 传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. 
+ :return: + """ if isinstance(set_input, str): set_input = [set_input] if isinstance(set_target, str): @@ -69,19 +96,6 @@ class MatchingLoader(JsonLoader): is_input=auto_set_input) if bert_tokenizer is not None: - PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', - 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', - 'en-base-cased': 'bert-base-cased-f89bfe08.zip', - 'en-large-uncased': 'bert-large-uncased-20939f45.zip', - 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', - - 'cn': 'bert-base-chinese-29d0a84a.zip', - 'cn-base': 'bert-base-chinese-29d0a84a.zip', - - 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', - 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', - 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', - } if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: PRETRAIN_URL = _get_base_url('bert') model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] @@ -128,14 +142,14 @@ class MatchingLoader(JsonLoader): for fields in data_set.get_field_names(): if Const.INPUT in fields: data_set.apply(lambda x: len(x[fields]), - new_field_name=fields.replace(Const.INPUT, Const.TARGET), + new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), is_input=auto_set_input) elif seq_len_type == 'mask': for data_name, data_set in data_info.datasets.items(): for fields in data_set.get_field_names(): if Const.INPUT in fields: data_set.apply(lambda x: [1] * len(x[fields]), - new_field_name=fields.replace(Const.INPUT, Const.TARGET), + new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), is_input=auto_set_input) elif seq_len_type == 'bert': for data_name, data_set in data_info.datasets.items(): @@ -152,11 +166,18 @@ class MatchingLoader(JsonLoader): if bert_tokenizer is not None: words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') + with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + words_vocab.add_word_lst(lines) + words_vocab.build_vocab() else: words_vocab = Vocabulary() - words_vocab = words_vocab.from_dataset(*data_set_list, - field_name=[n for n in data_set_list[0].get_field_names() - if (Const.INPUT in n)]) + words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], + field_name=[n for n in data_set_list[0].get_field_names() + if (Const.INPUT in n)], + no_create_entry_dataset=[d for n, d in data_info.datasets.items() + if 'train' not in n]) target_vocab = Vocabulary(padding=None, unknown=None) target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET) data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} @@ -173,14 +194,14 @@ class MatchingLoader(JsonLoader): for data_name, data_set in data_info.datasets.items(): if isinstance(set_input, list): - data_set.set_input(set_input) + data_set.set_input(*set_input) if isinstance(set_target, list): - data_set.set_target(set_target) + data_set.set_target(*set_target) return data_info -class SNLILoader(MatchingLoader): +class SNLILoader(MatchingLoader, JsonLoader): """ 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` @@ -203,10 +224,13 @@ class SNLILoader(MatchingLoader): 'train': 'snli_1.0_train.jsonl', 'dev': 'snli_1.0_dev.jsonl', 'test': 'snli_1.0_test.jsonl'} - super(SNLILoader, self).__init__(fields=fields, paths=paths) + # super(SNLILoader, self).__init__(fields=fields, paths=paths) + MatchingLoader.__init__(self, paths=paths) + JsonLoader.__init__(self, 
fields=fields) def _load(self, path): - ds = super(SNLILoader, self)._load(path) + # ds = super(SNLILoader, self)._load(path) + ds = JsonLoader._load(self, path) def parse_tree(x): t = Tree.fromstring(x) diff --git a/reproduction/matching/matching_esim.py b/reproduction/matching/matching_esim.py new file mode 100644 index 00000000..3da6141f --- /dev/null +++ b/reproduction/matching/matching_esim.py @@ -0,0 +1,65 @@ + +import argparse +import torch + +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const +from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding + +from reproduction.matching.data.MatchingDataLoader import SNLILoader +from reproduction.matching.model.esim import ESIMModel + +argument = argparse.ArgumentParser() +argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') +argument.add_argument('--batch-size-per-gpu', type=int, default=128) +argument.add_argument('--n-epochs', type=int, default=100) +argument.add_argument('--lr', type=float, default=1e-4) +argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') +argument.add_argument('--save-dir', type=str, default=None) +arg = argument.parse_args() + +bert_dirs = 'path/to/bert/dir' + +# load data set +data_info = SNLILoader().process( + paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, + get_index=True, concat=False, +) + +# load embedding +if arg.embedding == 'elmo': + embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) +elif arg.embedding == 'glove': + embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) +else: + raise ValueError(f'now we only support elmo or glove embedding for esim model!') + +# define model +model = ESIMModel(embedding) + +# define trainer +trainer = Trainer(train_data=data_info.datasets['train'], model=model, + optimizer=Adam(lr=arg.lr, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, + n_epochs=arg.n_epochs, print_every=-1, + dev_data=data_info.datasets['dev'], + metrics=AccuracyMetric(), metric_key='acc', + device=[i for i in range(torch.cuda.device_count())], + check_code_level=-1, + save_path=arg.save_path) + +# train model +trainer.train(load_best_model=True) + +# define tester +tester = Tester( + data=data_info.datasets['test'], + model=model, + metrics=AccuracyMetric(), + batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, + device=[i for i in range(torch.cuda.device_count())], +) + +# test model +tester.test() + + diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py index 0551bbdb..d55034e7 100644 --- a/reproduction/matching/model/esim.py +++ b/reproduction/matching/model/esim.py @@ -30,24 +30,37 @@ class ESIMModel(BaseModel): self.bi_attention = SoftmaxAttention() self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) - # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True) + # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,) self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(8 * hidden_size, hidden_size), nn.Tanh(), nn.Dropout(p=dropout_rate), nn.Linear(hidden_size, num_labels)) + + self.dropout_rnn = nn.Dropout(p=dropout_rate) + nn.init.xavier_uniform_(self.classifier[1].weight.data) nn.init.xavier_uniform_(self.classifier[4].weight.data) def forward(self, words1, words2, seq_len1, 
seq_len2, target=None): - mask1 = seq_len_to_mask(seq_len1) - mask2 = seq_len_to_mask(seq_len2) + """ + :param words1: [batch, seq_len] + :param words2: [batch, seq_len] + :param seq_len1: [batch] + :param seq_len2: [batch] + :param target: + :return: + """ + mask1 = seq_len_to_mask(seq_len1, words1.size(1)) + mask2 = seq_len_to_mask(seq_len2, words2.size(1)) a0 = self.embedding(words1) # B * len * emb_dim b0 = self.embedding(words2) a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] b = self.rnn(b0, mask2.byte()) + # a = self.dropout_rnn(self.rnn(a0, seq_len1)[0]) # a: [B, PL, 2 * H] + # b = self.dropout_rnn(self.rnn(b0, seq_len2)[0]) ai, bi = self.bi_attention(a, mask1, b, mask2) @@ -58,6 +71,8 @@ class ESIMModel(BaseModel): a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] b_h = self.rnn_high(b_f, mask2.byte()) + # a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0]) # ma: [B, PL, 2 * H] + # b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0]) a_avg = self.mean_pooling(a_h, mask1, dim=1) a_max, _ = self.max_pooling(a_h, mask1, dim=1) From bc5e071253c2a13ef055d13ac6b88f57bc7038e0 Mon Sep 17 00:00:00 2001 From: xuyige Date: Mon, 24 Jun 2019 21:56:14 +0800 Subject: [PATCH 03/10] Delete matching.py --- reproduction/matching/matching.py | 44 ------------------------------- 1 file changed, 44 deletions(-) delete mode 100644 reproduction/matching/matching.py diff --git a/reproduction/matching/matching.py b/reproduction/matching/matching.py deleted file mode 100644 index 8251b3bc..00000000 --- a/reproduction/matching/matching.py +++ /dev/null @@ -1,44 +0,0 @@ -import os - -import torch - -from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const - -from fastNLP.io.dataset_loader import MatchingLoader - -from reproduction.matching.model.bert import BertForNLI -from reproduction.matching.model.esim import ESIMModel - - -bert_dirs = 'path/to/bert/dir' - -# load data set -# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(... 
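One detail worth noting from the ESIM forward() change in [PATCH 02/10] above: seq_len_to_mask is now given words.size(1) as an explicit maximum length. A small sketch of why, assuming the helper accepts a max_len argument as that call implies:

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    words = torch.zeros(2, 5, dtype=torch.long)   # a batch padded to width 5
    seq_len = torch.tensor([2, 3])

    # without the second argument the mask would only be seq_len.max() == 3
    # columns wide; padding it out to words.size(1) keeps mask and input
    # aligned, e.g. when DataParallel hands a replica a sub-batch whose
    # longest sentence is shorter than the padded width
    mask = seq_len_to_mask(seq_len, words.size(1))
    print(mask.shape)   # torch.Size([2, 5])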
-data_info = MatchingLoader(data_format='snli', for_model='esim').process( - {'train': './data/snli/snli_1.0_train.jsonl', - 'dev': './data/snli/snli_1.0_dev.jsonl', - 'test': './data/snli/snli_1.0_test.jsonl'}, - input_field=[Const.TARGET] -) - -# model = BertForNLI(bert_dir=bert_dirs) -model = ESIMModel(data_info.embeddings['elmo'],) - -trainer = Trainer(train_data=data_info.datasets['train'], model=model, - optimizer=Adam(lr=1e-4, model_params=model.parameters()), - batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1, - dev_data=data_info.datasets['dev'], - metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], - check_code_level=-1) -trainer.train(load_best_model=True) - -tester = Tester( - data=data_info.datasets['test'], - model=model, - metrics=AccuracyMetric(), - batch_size=torch.cuda.device_count() * 12, - device=[i for i in range(torch.cuda.device_count())], -) -tester.test() - - From 50faa936b44193b2dc44c356e68a8d8b45119f1a Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 25 Jun 2019 17:22:10 +0800 Subject: [PATCH 04/10] add RTE and QNLI loader --- .../matching/data/MatchingDataLoader.py | 107 +++++++++++++++--- 1 file changed, 93 insertions(+), 14 deletions(-) diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 4868598a..0e4e1283 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -1,13 +1,12 @@ import os -from nltk import Tree from typing import Union, Dict from fastNLP.core.const import Const from fastNLP.core.vocabulary import Vocabulary from fastNLP.io.base_loader import DataInfo -from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader +from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from fastNLP.modules.encoder._bert import BertTokenizer @@ -35,7 +34,7 @@ class MatchingLoader(DataSetLoader): def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, - get_index=True, set_input: Union[list, str, bool]=True, + cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: """ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, @@ -48,6 +47,7 @@ class MatchingLoader(DataSetLoader): 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 + :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 :param bool get_index: 是否需要根据词表将文本转为index :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, @@ -161,6 +161,13 @@ class MatchingLoader(DataSetLoader): data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) + if cut_text is not None: + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): + data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, + is_input=auto_set_input) + data_set_list = [d for n, d in data_info.datasets.items()] assert 
len(data_set_list) > 0, f'There are NO data sets in data info!' @@ -216,32 +223,104 @@ class SNLILoader(MatchingLoader, JsonLoader): def __init__(self, paths: dict=None): fields = { - 'sentence1_parse': Const.INPUTS(0), - 'sentence2_parse': Const.INPUTS(1), + 'sentence1_binary_parse': Const.INPUTS(0), + 'sentence2_binary_parse': Const.INPUTS(1), 'gold_label': Const.TARGET, } paths = paths if paths is not None else { 'train': 'snli_1.0_train.jsonl', 'dev': 'snli_1.0_dev.jsonl', 'test': 'snli_1.0_test.jsonl'} - # super(SNLILoader, self).__init__(fields=fields, paths=paths) MatchingLoader.__init__(self, paths=paths) JsonLoader.__init__(self, fields=fields) def _load(self, path): - # ds = super(SNLILoader, self)._load(path) ds = JsonLoader._load(self, path) - def parse_tree(x): - t = Tree.fromstring(x) - return t.leaves() + parentheses_table = str.maketrans({'(': None, ')': None}) - ds.apply(lambda ins: parse_tree( - ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) - ds.apply(lambda ins: parse_tree( - ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) + ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(0)) + ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(1)) ds.drop(lambda x: x[Const.TARGET] == '-') return ds +class RTELoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` + + 读取RTE数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + # 'test': 'test.tsv' # test set has not label + } + MatchingLoader.__init__(self, paths=paths) + self.fields = { + 'sentence1': Const.INPUTS(0), + 'sentence2': Const.INPUTS(1), + 'label': Const.TARGET, + } + CSVLoader.__init__(self, sep='\t') + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + ds.rename_field(k, v) + for fields in ds.get_all_fields(): + if Const.INPUT in fields: + ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) + + return ds + + +class QNLILoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` + + 读取QNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + # 'test': 'test.tsv' # test set has not label + } + MatchingLoader.__init__(self, paths=paths) + self.fields = { + 'question': Const.INPUTS(0), + 'sentence': Const.INPUTS(1), + 'label': Const.TARGET, + } + CSVLoader.__init__(self, sep='\t') + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + ds.rename_field(k, v) + for fields in ds.get_all_fields(): + if Const.INPUT in fields: + ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) + + return ds From 40c4d216d19ebf02515607e8d0e649d2cc781ca5 Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 26 Jun 2019 13:39:14 +0800 Subject: [PATCH 05/10] =?UTF-8?q?=E4=BF=AE=E6=94=B9staticEmbedding?= =?UTF-8?q?=E7=9A=84=E5=88=9D=E5=A7=8B=E5=8C=96=E6=96=B9=E5=BC=8F=EF=BC=8C?= 
=?UTF-8?q?=E6=98=BE=E7=A4=BA=E9=80=9A=E8=BF=87=E8=BF=99=E7=A7=8D=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E5=9C=A8esmi=E4=B8=8A=E7=9A=84snli=E6=9B=B4?= =?UTF-8?q?=E5=AE=B9=E6=98=93=E8=BE=BE=E5=88=B088=E7=9A=84test=20acc?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/encoder/embedding.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index a58668da..c48cb806 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -180,11 +180,11 @@ class StaticEmbedding(TokenEmbedding): 的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d, `en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 :param requires_grad: 是否需要gradient. 默认为True - :param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.xavier_uniform_ - 。调用该方法时传入一个tensor对象。 - + :param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 + :param normailize: 是否对vector进行normalize,使得每个vector的norm为1。 """ - def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None): + def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None, + normalize=False): super(StaticEmbedding, self).__init__(vocab) # 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, @@ -202,7 +202,8 @@ class StaticEmbedding(TokenEmbedding): raise ValueError(f"Cannot recognize {model_dir_or_name}.") # 读取embedding - embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) + embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, + normalize=normalize) self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], padding_idx=vocab.padding_idx, max_norm=None, norm_type=2, scale_grad_by_freq=False, @@ -257,10 +258,7 @@ class StaticEmbedding(TokenEmbedding): assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." 
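The initialization change this patch describes, implemented a little further down in this hunk, replaces the old blanket Xavier init: rows for words missing from the pre-trained file are drawn from a per-dimension normal fitted to the rows that were found, and the whole matrix can optionally be length-normalized. A standalone sketch of the idea (function name and signature are illustrative, not fastNLP API):

    import torch

    def init_missing_rows(matrix, hit_mask, normalize=False):
        # matrix: (vocab_size, dim) with the pre-trained rows already copied in
        # hit_mask: bool tensor, True where a pre-trained vector was found
        found, missing = matrix[hit_mask], ~hit_mask
        if found.size(0) > 0 and missing.any():
            mean = found.mean(dim=0, keepdim=True)
            std = found.std(dim=0, keepdim=True)
            # sample the unknown rows around the statistics of the known ones
            matrix[missing] = torch.randn(int(missing.sum()), matrix.size(1)) * std + mean
        if normalize:
            # unit-norm every row, matching the new `normalize` flag
            matrix = matrix / (matrix.norm(dim=1, keepdim=True) + 1e-12)
        return matrix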
if not os.path.exists(embed_filepath): raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) - if init_method is None: - init_method = nn.init.xavier_uniform_ with open(embed_filepath, 'r', encoding='utf-8') as f: - found_count = 0 line = f.readline().strip() parts = line.split() start_idx = 0 @@ -271,7 +269,8 @@ class StaticEmbedding(TokenEmbedding): dim = len(parts) - 1 f.seek(0) matrix = torch.zeros(len(vocab), dim) - init_method(matrix) + if init_method is not None: + init_method(matrix) hit_flags = np.zeros(len(vocab), dtype=bool) for idx, line in enumerate(f, start_idx): try: @@ -286,7 +285,6 @@ class StaticEmbedding(TokenEmbedding): if word in vocab: index = vocab.to_index(word) matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) - found_count += 1 hit_flags[index] = True except Exception as e: if error == 'ignore': @@ -294,7 +292,16 @@ class StaticEmbedding(TokenEmbedding): else: print("Error occurred at the {} line.".format(idx)) raise e + found_count = sum(hit_flags) print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) + if init_method is None: + if len(vocab)-found_count>0 and found_count>0: # 有的没找到 + found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()] + mean = found_vecs.mean(dim=0, keepdim=True) + std = found_vecs.std(dim=0, keepdim=True) + unfound_vec_num = np.sum(hit_flags==False) + unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean + matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs if normalize: matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) From 9c1b4914d8f4fda018f449cf5374941b1fa03c9d Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 30 Jun 2019 09:52:01 +0800 Subject: [PATCH 06/10] =?UTF-8?q?1.=E4=BF=AE=E5=A4=8Dtrainer=E4=B8=AD?= =?UTF-8?q?=E6=BD=9C=E5=9C=A8=E5=A4=9A=E6=AD=A5=E6=9B=B4=E6=96=B0bug;=202.?= =?UTF-8?q?=20LSTM=E7=9A=84=E6=95=B0=E6=8D=AE=E5=B9=B6=E8=A1=8C=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=EF=BC=9B3.=20embed=5Floader=E4=B8=ADbug=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D,=20=E4=B8=94=E5=85=81=E8=AE=B8=E6=89=8B=E5=8A=A8?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 2 +- fastNLP/core/dataset.py | 6 ++++-- fastNLP/core/optimizer.py | 17 +++++++++++++++++ fastNLP/core/trainer.py | 6 +++--- fastNLP/io/embed_loader.py | 33 +++++++++++++++++++-------------- fastNLP/modules/encoder/lstm.py | 11 +---------- fastNLP/modules/utils.py | 2 ++ setup.py | 2 +- 8 files changed, 48 insertions(+), 31 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 483f6dc1..5dfd889b 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -548,7 +548,7 @@ class LRScheduler(Callback): else: raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. 
Got {type(lr_scheduler)}.") - def on_epoch_begin(self): + def on_epoch_end(self): self.scheduler.step(self.epoch) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 4cd1ad9c..b7df9dec 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -801,17 +801,19 @@ class DataSet(object): else: return DataSet() - def split(self, ratio): + def split(self, ratio, shuffle=True): """ 将DataSet按照ratio的比例拆分,返回两个DataSet :param float ratio: 0 [N,L,C] - output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first) + output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) if self.batch_first: output = output[unsort_idx] else: output = output[:, unsort_idx] - # 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591 - if self.batch_first: - if output.size(1) < max_len: - dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1)) - output = torch.cat([output, dummy_tensor], 1) - else: - if output.size(0) < max_len: - dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1)) - output = torch.cat([output, dummy_tensor], 0) else: output, hx = self.lstm(x, hx) return output, hx diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index c87f3a68..3c6a3d27 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -82,6 +82,8 @@ def get_embeddings(init_embed): if isinstance(init_embed, tuple): res = nn.Embedding( num_embeddings=init_embed[0], embedding_dim=init_embed[1]) + nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)), + b=np.sqrt(3/res.weight.data.size(1))) elif isinstance(init_embed, nn.Module): res = init_embed elif isinstance(init_embed, torch.Tensor): diff --git a/setup.py b/setup.py index 49646761..0dbef455 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f: setup( name='FastNLP', - version='0.4.0', + version='dev0.5.0', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', long_description=readme, long_description_content_type='text/markdown', From 15d9581e6d0805ad52a3f7c367d329999e3841e2 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 30 Jun 2019 15:44:26 +0800 Subject: [PATCH 07/10] fix a bug in predictor --- fastNLP/core/predictor.py | 4 +- .../matching/data/MatchingDataLoader.py | 93 ++++++++++++++++--- 2 files changed, 82 insertions(+), 15 deletions(-) diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 06e586c6..ce016bb6 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -9,7 +9,7 @@ import torch from . import DataSetIter from . import DataSet from . 
import SequentialSampler -from .utils import _build_args +from .utils import _build_args, _move_dict_value_to_device, _get_model_device class Predictor(object): @@ -43,6 +43,7 @@ class Predictor(object): raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) self.network.eval() + network_device = _get_model_device(self.network) batch_output = defaultdict(list) data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) @@ -53,6 +54,7 @@ class Predictor(object): with torch.no_grad(): for batch_x, _ in data_iterator: + _move_dict_value_to_device(batch_x, _, device=network_device) refined_batch_x = _build_args(predict_func, **batch_x) prediction = predict_func(**refined_batch_x) diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 0e4e1283..749b16c8 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -86,7 +86,8 @@ class MatchingLoader(DataSetLoader): if auto_set_input: data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) if auto_set_target: - data_set.set_target(Const.TARGET) + if Const.TARGET in data_set.get_field_names(): + data_set.set_target(Const.TARGET) if to_lower: for data_name, data_set in data_info.datasets.items(): @@ -107,6 +108,13 @@ class MatchingLoader(DataSetLoader): else: raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") + words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') + with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + words_vocab.add_word_lst(lines) + words_vocab.build_vocab() + tokenizer = BertTokenizer.from_pretrained(model_dir) for data_name, data_set in data_info.datasets.items(): @@ -171,14 +179,7 @@ class MatchingLoader(DataSetLoader): data_set_list = [d for n, d in data_info.datasets.items()] assert len(data_set_list) > 0, f'There are NO data sets in data info!' 
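The replacement lines just below build the word vocabulary from the training split(s) only and feed every non-train split in through no_create_entry_dataset. A minimal sketch of that pattern with toy data (field name and contents are illustrative):

    from fastNLP.core import DataSet, Vocabulary

    train = DataSet({'words1': [['a', 'cat', 'sat'], ['a', 'dog', 'ran']]})
    test = DataSet({'words1': [['a', 'bird', 'flew']]})

    vocab = Vocabulary()
    # the index is created from train; words seen only in test are flagged as
    # no-create-entry entries, which downstream embeddings can treat specially
    # (e.g. fall back to a pre-trained vector rather than learning a new one)
    vocab.from_dataset(train, field_name='words1', no_create_entry_dataset=[test])
    print(vocab.to_index('cat'), vocab.to_index('bird'))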
- if bert_tokenizer is not None: - words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') - with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: - lines = f.readlines() - lines = [line.strip() for line in lines] - words_vocab.add_word_lst(lines) - words_vocab.build_vocab() - else: + if bert_tokenizer is None: words_vocab = Vocabulary() words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], field_name=[n for n in data_set_list[0].get_field_names() @@ -186,7 +187,8 @@ class MatchingLoader(DataSetLoader): no_create_entry_dataset=[d for n, d in data_info.datasets.items() if 'train' not in n]) target_vocab = Vocabulary(padding=None, unknown=None) - target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET) + target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], + field_name=Const.TARGET) data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} if get_index: @@ -196,14 +198,15 @@ class MatchingLoader(DataSetLoader): data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, is_input=auto_set_input) - data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, - is_input=auto_set_input, is_target=auto_set_target) + if Const.TARGET in data_set.get_field_names(): + data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, + is_input=auto_set_input, is_target=auto_set_target) for data_name, data_set in data_info.datasets.items(): if isinstance(set_input, list): - data_set.set_input(*set_input) + data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) if isinstance(set_target, list): - data_set.set_target(*set_target) + data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) return data_info @@ -324,3 +327,65 @@ class QNLILoader(MatchingLoader, CSVLoader): return ds + +class MNLILoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` + + 读取SNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev_matched': 'dev_matched.tsv', + 'dev_mismatched': 'dev_mismatched.tsv', + 'test_matched': 'test_matched.tsv', + 'test_mismatched': 'test_mismatched.tsv', + } + MatchingLoader.__init__(self, paths=paths) + CSVLoader.__init__(self, sep='\t') + self.fields = { + 'sentence1_binary_parse': Const.INPUTS(0), + 'sentence2_binary_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + if k in ds.get_field_names(): + ds.rename_field(k, v) + + parentheses_table = str.maketrans({'(': None, ')': None}) + + ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(0)) + ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(1)) + if Const.TARGET in ds.get_field_names(): + ds.drop(lambda x: x[Const.TARGET] == '-') + return ds + + +class QuoraLoader(MatchingLoader, CSVLoader): + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + 'test': 
'test.tsv', + } + MatchingLoader.__init__(self, paths=paths) + CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) + + def _load(self, path): + ds = CSVLoader._load(self, path) + return ds From 5f19601d202d3cbfeb6a57fd9721e63d059e34e2 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 1 Jul 2019 00:01:45 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E6=94=AF=E6=8C=81predict=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=B9=B6=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 10 ++++------ fastNLP/core/utils.py | 23 ++++++++++++++++++++++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 4cdd4ffb..6a0fdb9a 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -48,6 +48,7 @@ from .utils import _move_dict_value_to_device from .utils import _get_func_signature from .utils import _get_model_device from .utils import _move_model_to_device +from .utils import _data_parallel_wrapper __all__ = [ "Tester" @@ -113,12 +114,9 @@ class Tester(object): self._model = self._model.module # check predict - if hasattr(self._model, 'predict'): - self._predict_func = self._model.predict - if not callable(self._predict_func): - _model_name = model.__class__.__name__ - raise TypeError(f"`{_model_name}.predict` must be callable to be used " - f"for evaluation, not `{type(self._predict_func)}`.") + if hasattr(self._model, 'predict') and callable(self._model.predict): + self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids, + self._model.output_device) else: if isinstance(model, nn.DataParallel): self._predict_func = self._model.module.forward diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d26df966..8fe764f8 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -16,7 +16,9 @@ from collections import Counter, namedtuple import numpy as np import torch import torch.nn as nn - +from torch.nn.parallel.scatter_gather import scatter_kwargs, gather +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.parallel_apply import parallel_apply _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', 'varargs']) @@ -277,6 +279,25 @@ def _move_model_to_device(model, device): model = model.to(device) return model +def _data_parallel_wrapper(func, device_ids, output_device): + """ + 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 + + :param func: callable + :param device_ids: nn.DataParallel中的device_ids + :param inputs: + :param kwargs: + :return: + """ + def wrapper(*inputs, **kwargs): + inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) + if len(device_ids) == 1: + return func(*inputs[0], **kwargs[0]) + replicas = replicate(func, device_ids[:len(inputs)]) + outputs = parallel_apply(replicas, inputs, kwargs) + return gather(outputs, output_device) + return wrapper + def _get_model_device(model): """ From f68b2c5382b6411d9aca8b674194b94896d296e9 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 1 Jul 2019 00:33:31 +0800 Subject: [PATCH 09/10] =?UTF-8?q?Tester=E6=94=AF=E6=8C=81predict=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=B9=B6=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 
6a0fdb9a..4fa31fd2 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -105,18 +105,16 @@ class Tester(object): self.data_iterator = data else: raise TypeError("data type {} not support".format(type(data))) - - # 如果是DataParallel将没有办法使用predict方法 - if isinstance(self._model, nn.DataParallel): - if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): - warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," - " while DataParallel has no predict() function.") - self._model = self._model.module - + # check predict - if hasattr(self._model, 'predict') and callable(self._model.predict): - self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids, - self._model.output_device) + if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ + (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and + callable(self._model.module.predict)): + if isinstance(self._model, nn.DataParallel): + self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, + self._model.output_device) + else: + self._predict_func = self._model.predict else: if isinstance(model, nn.DataParallel): self._predict_func = self._model.module.forward From 3c984872d3cd968750ee127f5e8cdf4b0106935f Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 1 Jul 2019 01:03:30 +0800 Subject: [PATCH 10/10] =?UTF-8?q?Tester=E6=95=B0=E6=8D=AE=E5=B9=B6?= =?UTF-8?q?=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 4fa31fd2..68950c10 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -111,15 +111,19 @@ class Tester(object): (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and callable(self._model.module.predict)): if isinstance(self._model, nn.DataParallel): - self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, + self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, self._model.output_device) + self._predict_func = self._model.module.predict else: self._predict_func = self._model.predict + self._predict_func_wrapper = self._model.predict else: - if isinstance(model, nn.DataParallel): + if isinstance(self._model, nn.DataParallel): + self._predict_func_wrapper = self._model.forward self._predict_func = self._model.module.forward else: self._predict_func = self._model.forward + self._predict_func_wrapper = self._model.forward def test(self): """开始进行验证,并返回验证结果。 @@ -176,7 +180,7 @@ class Tester(object): def _data_forward(self, func, x): """A forward pass of the model. """ x = _build_args(func, **x) - y = func(**x) + y = self._predict_func_wrapper(**x) return y def _format_eval_results(self, results):
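Taken together, the last three patches let the Tester evaluate through a model's predict() method even when the model is wrapped in nn.DataParallel: the unwrapped predict supplies the signature used to build arguments, while _data_parallel_wrapper scatters each batch over the visible devices, runs the replicas, and gathers the outputs. A rough sketch of the intended end-to-end usage (ToyModel, field names and data are illustrative; with no visible GPU the device argument simply falls back to None):

    import torch
    import torch.nn as nn
    from fastNLP.core import DataSet, Tester, AccuracyMetric

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 8)
            self.fc = nn.Linear(8, 2)

        def forward(self, words):
            return {'pred': self.fc(self.embed(words).mean(dim=1))}

        def predict(self, words):
            # inference-only path; the Tester now prefers this over forward()
            return {'pred': self.fc(self.embed(words).mean(dim=1))}

    ds = DataSet({'words': [[1, 2, 3], [4, 5, 6]] * 4, 'target': [0, 1] * 4})
    ds.set_input('words')
    ds.set_target('target')

    # a list of device ids makes fastNLP wrap the model in nn.DataParallel,
    # and predict() is then dispatched through the scatter/gather wrapper
    devices = list(range(torch.cuda.device_count())) or None
    tester = Tester(data=ds, model=ToyModel(), metrics=AccuracyMetric(),
                    batch_size=4, device=devices)
    tester.test()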