From 43fac849f9e21bf28a99da94d69d5d1674ff7e47 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 8 Jul 2019 22:49:43 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E5=A2=9E=E5=8A=A0learning=20rate=20Warmup?= =?UTF-8?q?Callback;=202.=E5=A2=9E=E5=8A=A0=E6=A8=A1=E5=9E=8B=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E7=9A=84callback;=203.=20utils=E4=B8=AD=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=AF=B9bio=E7=B1=BB=E5=9E=8B=E7=9A=84tag=E7=9A=84?= =?UTF-8?q?=E5=A4=84=E7=90=86;=204.=20embedding=E4=B8=AD=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?word=5Fdropout=E4=B8=8Edropout=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/batch.py | 27 +++++- fastNLP/core/callback.py | 128 +++++++++++++++++++++++++ fastNLP/core/utils.py | 77 ++++++++++++++- fastNLP/modules/encoder/embedding.py | 135 +++++++++++++++++++-------- 4 files changed, 323 insertions(+), 44 deletions(-) diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index ca48a8e1..2d8c1a80 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -3,7 +3,6 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 """ __all__ = [ - "BatchIter", "DataSetIter", "TorchLoaderIter", ] @@ -50,6 +49,7 @@ class DataSetGetter: return len(self.dataset) def collate_fn(self, batch: list): + # TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 batch_x = {n:[] for n in self.inputs.keys()} batch_y = {n:[] for n in self.targets.keys()} indices = [] @@ -136,6 +136,31 @@ class BatchIter: class DataSetIter(BatchIter): + """ + 别名::class:`fastNLP.DataSetIter` :class:`fastNLP.core.batch.DataSetIter` + + DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, + 组成 `x` 和 `y`:: + + batch = DataSetIter(data_set, batch_size=16, sampler=SequentialSampler()) + num_batch = len(batch) + for batch_x, batch_y in batch: + # do stuff ... + + :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 + :param int batch_size: 取出的batch大小 + :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. + + Default: ``None`` + :param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. + + Default: ``False`` + :param int num_workers: 使用多少个进程来预处理数据 + :param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 + :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 + :param timeout: + :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 + """ def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 0b1890f8..bbe2f325 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -66,6 +66,8 @@ import os import torch from copy import deepcopy +import sys +from .utils import _save_model try: from tensorboardX import SummaryWriter @@ -737,6 +739,132 @@ class TensorboardCallback(Callback): del self._summary_writer +class WarmupCallback(Callback): + """ + 按一定的周期调节Learning rate的大小。 + + :param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, + 如0.1, 则前10%的step是按照schedule策略调整learning rate。 + :param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后 + warmup的step下降到0; constant前warmup的step上升到指定learning rate,后面的step保持learning rate. + """ + def __init__(self, warmup=0.1, schedule='constant'): + super().__init__() + self.warmup = max(warmup, 0.) 
+
+        self.initial_lrs = []  # 存放param_group的learning rate
+        if schedule == 'constant':
+            self.get_lr = self._get_constant_lr
+        elif schedule == 'linear':
+            self.get_lr = self._get_linear_lr
+        else:
+            raise RuntimeError("Only support 'linear', 'constant'.")
+
+    def _get_constant_lr(self, progress):
+        if progress<self.warmup:
+            return progress/self.warmup
+        return 1
+
+    def _get_linear_lr(self, progress):
+        if progress<self.warmup:
+            return progress/self.warmup
+        else:
+            return max((progress-1.)/(self.warmup-1.), 0.)
+
+    def on_train_begin(self):
+        self.t_steps = (len(self.trainer.train_data) // (self.batch_size*self.update_every) +
+                        int(len(self.trainer.train_data) % (self.batch_size*self.update_every) != 0)) * self.n_epochs
+        if self.warmup>1:
+            self.warmup = self.warmup/self.t_steps
+        self.t_steps = max(2, self.t_steps)  # 不能小于2
+        # 获取param_group的初始learning rate
+        for group in self.optimizer.param_groups:
+            self.initial_lrs.append(group['lr'])
+
+    def on_backward_end(self):
+        if self.step%self.update_every==0:
+            progress = (self.step/self.update_every)/self.t_steps
+            for lr, group in zip(self.initial_lrs, self.optimizer.param_groups):
+                group['lr'] = lr * self.get_lr(progress)
+
+
+class SaveModelCallback(Callback):
+    """
+    由于Trainer在训练过程中只会保存最佳的模型,该callback可实现多种方式的结果存储。
+    会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型::
+
+        -save_dir
+            -2019-07-03-15-06-36
+                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt   # metric是给定的metric_key, evaluate_performance是性能
+                -epoch:1_step:40_{metric_key}:{evaluate_performance}.pt
+            -2019-07-03-15-10-00
+                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt   # metric是给定的metric_key, evaluate_performance是性能
+
+    :param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型
+    :param int top: 保存dev表现top多少的模型。-1为保存所有模型。
+    :param bool only_param: 是否只保存模型的权重。
+    :param bool save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}.
+    """
+    def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False):
+        super().__init__()
+
+        if not os.path.isdir(save_dir):
+            raise NotADirectoryError("{} is not a directory.".format(save_dir))
+        self.save_dir = save_dir
+        if top < 0:
+            self.top = sys.maxsize
+        else:
+            self.top = top
+        self._ordered_save_models = []  # List[Tuple], Tuple[0]是metric, Tuple[1]是path。metric是依次变好的,所以从头删
+
+        self.only_param = only_param
+        self.save_on_exception = save_on_exception
+
+    def on_train_begin(self):
+        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)
+
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
+        metric_value = list(eval_result.values())[0][metric_key]
+        self._save_this_model(metric_value)
+
+    def _insert_into_ordered_save_models(self, pair):
+        # pair: (metric_value, model_name)
+        # 返回save的模型pair与删除的模型pair. pair中第一个元素是metric的值,第二个元素是模型的名称
+        index = -1
+        for _pair in self._ordered_save_models:
+            if _pair[0]>=pair[0] and self.trainer.increase_better:
+                break
+            if not self.trainer.increase_better and _pair[0]<=pair[0]:
+                break
+            index += 1
+        save_pair = None
+        if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1):
+            save_pair = pair
+            self._ordered_save_models.insert(index+1, pair)
+        delete_pair = None
+        if len(self._ordered_save_models)>self.top:
+            delete_pair = self._ordered_save_models.pop(0)
+        return save_pair, delete_pair
+
+    def _save_this_model(self, metric_value):
+        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
+        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
+        if save_pair:
+            try:
+                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
+            except Exception as e:
+                print(f"The following exception:{e} happens when save model to {self.save_dir}.")
+        if delete_pair:
+            try:
+                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
+                if os.path.exists(delete_model_path):
+                    os.remove(delete_model_path)
+            except Exception as e:
+                print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
+
+    def on_exception(self, exception):
+        if self.save_on_exception:
+            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
+            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
+
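A minimal usage sketch for the two callbacks above, assuming `model`, `train_data` and `dev_data` are an ordinary fastNLP model/DataSet pair, that any metric such as AccuracyMetric is used, and that the directory ./save_models already exists::

    from fastNLP import Trainer, AccuracyMetric
    from fastNLP.core.callback import WarmupCallback, SaveModelCallback

    callbacks = [
        WarmupCallback(warmup=0.1, schedule='linear'),       # warm up over the first 10% of steps, then decay linearly to 0
        SaveModelCallback(save_dir='./save_models', top=3),  # keep the 3 best checkpoints on dev
    ]
    trainer = Trainer(train_data, model, metrics=AccuracyMetric(), metric_key='acc',
                      dev_data=dev_data, callbacks=callbacks)
    trainer.train()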
 
 class CallbackException(BaseException):
     """
     当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 490f9f8f..9b23240c 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -16,6 +16,7 @@ from collections import Counter, namedtuple
 import numpy as np
 import torch
 import torch.nn as nn
+from typing import List
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -162,6 +163,30 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
         return wrapper_
 
 
+def _save_model(model, model_name, save_dir, only_param=False):
+    """ 存储不含有显卡信息的state_dict或model
+
+    :param model:
+    :param model_name:
+    :param save_dir: 保存的directory
+    :param only_param:
+    :return:
+    """
+    model_path = os.path.join(save_dir, model_name)
+    if not os.path.isdir(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+    if isinstance(model, nn.DataParallel):
+        model = model.module
+    if only_param:
+        state_dict = model.state_dict()
+        for key in state_dict:
+            state_dict[key] = state_dict[key].cpu()
+        torch.save(state_dict, model_path)
+    else:
+        _model_device = _get_model_device(model)
+        model.cpu()
+        torch.save(model, model_path)
+        model.to(_model_device)
+
 
 # def save_pickle(obj, pickle_path, file_name):
 #     """Save an object into a pickle file.
@@ -277,7 +302,6 @@ def _move_model_to_device(model, device):
     return model
 
 
-
 def _get_model_device(model):
     """
     传入一个nn.Module的模型,获取它所在的device
@@ -285,7 +309,7 @@ def _get_model_device(model):
     :param model: nn.Module
     :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。
     """
-    # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding
+    # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡
     assert isinstance(model, nn.Module)
 
     parameters = list(model.parameters())
@@ -712,3 +736,52 @@ class _pseudo_tqdm:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         del self
+
+def iob2(tags: List[str]) -> List[str]:
+    """
+    检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见
+    https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
+
+    :param tags: 需要转换的tags, 需要为大写的BIO标签。
+    """
+    for i, tag in enumerate(tags):
+        if tag == "O":
+            continue
+        split = tag.split("-")
+        if len(split) != 2 or split[0] not in ["I", "B"]:
+            raise TypeError("The encoding schema is not a valid IOB type.")
+        if split[0] == "B":
+            continue
+        elif i == 0 or tags[i - 1] == "O":  # conversion IOB1 to IOB2
+            tags[i] = "B" + tag[1:]
+        elif tags[i - 1][1:] == tag[1:]:
+            continue
+        else:  # conversion IOB1 to IOB2
+            tags[i] = "B" + tag[1:]
+    return tags
+
+def iob2bioes(tags: List[str]) -> List[str]:
+    """
+    将iob的tag转换为bioes编码
+
+    :param tags: List[str]. 编码需要是大写的。
+    :return:
+    """
+    new_tags = []
+    for i, tag in enumerate(tags):
+        if tag == 'O':
+            new_tags.append(tag)
+        else:
+            split = tag.split('-')[0]
+            if split == 'B':
+                if i+1 != len(tags) and tags[i+1].split('-')[0] == 'I':
+                    new_tags.append(tag)
+                else:
+                    new_tags.append(tag.replace('B-', 'S-'))
+            elif split == 'I':
+                if i+1 < len(tags) and tags[i+1].split('-')[0] == 'I':
+                    new_tags.append(tag)
+                else:
+                    new_tags.append(tag.replace('I-', 'E-'))
+            else:
+                raise TypeError("Invalid IOB format.")
+    return new_tags
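A quick sketch of what the two new tag helpers do; the example tags are made up and the call is standalone::

    from fastNLP.core.utils import iob2, iob2bioes

    tags = iob2(["I-PER", "I-PER", "O", "I-LOC"])   # IOB1 -> IOB2 (BIO)
    print(tags)             # ['B-PER', 'I-PER', 'O', 'B-LOC']
    print(iob2bioes(tags))  # ['B-PER', 'E-PER', 'O', 'S-LOC']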
diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
-            if dropout_word>0 and not isinstance(unk_index, int):
+            if word_dropout>0 and not isinstance(unk_index, int):
                 raise ValueError("When drop word is set, you need to pass in the unk_index.")
         else:
             self._embed_size = self.embed.embed_size
             unk_index = self.embed.get_word_vocab().unknown_idx
         self.unk_index = unk_index
-        self.dropout_word = dropout_word
+        self.word_dropout = word_dropout
 
     def forward(self, x):
         """
         :param torch.LongTensor x: [batch, seq_len]
         :return: torch.Tensor : [batch, seq_len, embed_dim]
         """
-        if self.dropout_word>0 and self.training:
-            mask = torch.ones_like(x).float() * self.dropout_word
+        if self.word_dropout>0 and self.training:
+            mask = torch.ones_like(x).float() * self.word_dropout
             mask = torch.bernoulli(mask).byte()  # dropout_word越大,越多位置为1
             x = x.masked_fill(mask, self.unk_index)
         x = self.embed(x)
@@ -117,11 +117,38 @@ class Embedding(nn.Module):
 
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, vocab):
+    def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
         super(TokenEmbedding, self).__init__()
-        assert vocab.padding_idx is not None, "You vocabulary must have padding."
+        assert vocab.padding is not None, "Vocabulary must have a padding entry."
         self._word_vocab = vocab
         self._word_pad_index = vocab.padding_idx
+        if word_dropout>0:
+            assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
+        self.word_dropout = word_dropout
+        self._word_unk_index = vocab.unknown_idx
+        self.dropout_layer = nn.Dropout(dropout)
+
+    def drop_word(self, words):
+        """
+        按照设定随机将words设置为unknown_index。
+
+        :param torch.LongTensor words: batch_size x max_len
+        :return:
+        """
+        if self.word_dropout > 0 and self.training:
+            mask = torch.ones_like(words).float() * self.word_dropout
+            mask = torch.bernoulli(mask).byte()  # word_dropout越大,越多位置为1
+            words = words.masked_fill(mask, self._word_unk_index)
+        return words
+
+    def dropout(self, words):
+        """
+        对embedding后的word表示进行drop。
+
+        :param torch.FloatTensor words: batch_size x max_len x embed_size
+        :return:
+        """
+        return self.dropout_layer(words)
 
     @property
     def requires_grad(self):
@@ -163,6 +190,9 @@ class TokenEmbedding(nn.Module):
     def size(self):
         return torch.Size(self.num_embedding, self._embed_size)
 
+    @abstractmethod
+    def forward(self, *input):
+        raise NotImplementedError
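The drop_word trick above can be sanity-checked on its own; a standalone sketch of the same mask/masked_fill logic with made-up toy indices (plain PyTorch, not the fastNLP API)::

    import torch

    word_dropout, unk_index = 0.3, 1
    words = torch.LongTensor([[4, 5, 6, 7],
                              [8, 9, 10, 0]])              # batch_size x max_len
    mask = torch.ones_like(words).float() * word_dropout   # each position is 1 with probability word_dropout
    mask = torch.bernoulli(mask).byte()
    dropped = words.masked_fill(mask, unk_index)           # selected positions are replaced by the unk index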
 
 class StaticEmbedding(TokenEmbedding):
     """
@@ -181,13 +211,15 @@ class StaticEmbedding(TokenEmbedding):
         `en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
     :param bool requires_grad: 是否需要gradient. 默认为True
     :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。
-    :param bool normailize: 是否对vector进行normalize,使得每个vector的norm为1。
     :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
         为大写的词语开辟一个vector表示,则将lower设置为False。
+    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+    :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
+    :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
-                 normalize=False, lower=False):
-        super(StaticEmbedding, self).__init__(vocab)
+                 lower=False, dropout=0, word_dropout=0, normalize=False):
+        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # 得到cache_path
         if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
@@ -362,12 +394,15 @@ class StaticEmbedding(TokenEmbedding):
         """
         if hasattr(self, 'words_to_words'):
             words = self.words_to_words[words]
-        return self.embedding(words)
+        words = self.drop_word(words)
+        words = self.embedding(words)
+        words = self.dropout(words)
+        return words
 
 
 class ContextualEmbedding(TokenEmbedding):
-    def __init__(self, vocab: Vocabulary):
-        super(ContextualEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):
+        super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
     def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True):
         """
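A small sketch of the new word_dropout/dropout switches on StaticEmbedding; the vocabulary is a toy one and the default 'en' pretrained vectors named in the signature above are downloaded on first use::

    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a demo sentence".split())
    embed = StaticEmbedding(vocab, model_dir_or_name='en',
                            word_dropout=0.01,   # ~1% of tokens become unk while training
                            dropout=0.3)         # Dropout applied to the produced embeddings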
@@ -473,12 +508,14 @@ class ElmoEmbedding(ContextualEmbedding):
         按照这个顺序concat起来。默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致,
         初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。)
     :param requires_grad: bool, 该层是否需要gradient, 默认为False.
+    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+    :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
     :param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding,
         并删除character encoder,之后将直接使用cache的embedding。默认为False。
     """
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
-                 layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
-        super(ElmoEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', layers: str='2', requires_grad: bool=False,
+                 word_dropout=0.0, dropout=0.0, cache_word_reprs: bool=False):
+        super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # 根据model_dir_or_name检查是否存在并下载
         if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
@@ -545,11 +582,13 @@ class ElmoEmbedding(ContextualEmbedding):
         :param words: batch_size x max_len
         :return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
         """
+        words = self.drop_word(words)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
-            return outputs
+            return self.dropout(outputs)
         outputs = self.model(words)
-        return self._get_outputs(outputs)
+        outputs = self._get_outputs(outputs)
+        return self.dropout(outputs)
 
     def _delete_model_weights(self):
         for name in ['layers', 'model', 'layer_weights', 'gamma']:
@@ -595,13 +634,16 @@ class BertEmbedding(ContextualEmbedding):
     :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
     :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
         中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。
+    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+    :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
     :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。这样
         会使得word embedding的结果比输入的结果长两个token。在使用 :class::StackEmbedding 可能会遇到问题。
     :param bool requires_grad: 是否需要gradient。
     """
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
-                 pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
-        super(BertEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
+                 pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False,
+                 include_cls_sep: bool=False):
+        super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # 根据model_dir_or_name检查是否存在并下载
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -632,13 +674,14 @@ class BertEmbedding(ContextualEmbedding):
         :param torch.LongTensor words: [batch_size, max_len]
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
         """
+        words = self.drop_word(words)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
-            return outputs
+            return self.dropout(outputs)
         outputs = self.model(words)
         outputs = torch.cat([*outputs], dim=-1)
 
-        return outputs
+        return self.dropout(outputs)
 
     @property
     def requires_grad(self):
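The same two knobs on a contextual embedding; a rough sketch with a toy vocabulary (the 'en-base-uncased' weights named in the signature default are fetched on first use)::

    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("hello world".split())
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased',
                          word_dropout=0.01, dropout=0.1, include_cls_sep=False)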
@@ -680,8 +723,8 @@ class CNNCharEmbedding(TokenEmbedding):
     """
     别名::class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`
 
-    使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool
-    -> fc. 不同的kernel大小的fitler结果是concat起来的。
+    使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
+    不同的kernel大小的filter结果是concat起来的。
 
     Example::
 
@@ -691,23 +734,24 @@ class CNNCharEmbedding(TokenEmbedding):
     :param vocab: 词表
     :param embed_size: 该word embedding的大小,默认值为50.
     :param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50.
-    :param dropout: 以多大的概率drop
+    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+    :param float dropout: 以多大的概率drop
     :param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20].
     :param kernel_sizes: kernel的大小. 默认值为[5, 3, 1].
     :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
     :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
     :param min_char_freq: character的最少出现次数。默认值为2.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5,
-                 filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
-                 activation='relu', min_char_freq: int=2):
-        super(CNNCharEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
+                 dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
+                 pool_method: str='max', activation='relu', min_char_freq: int=2):
+        super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         for kernel in kernel_sizes:
             assert kernel % 2 == 1, "Only odd kernel is allowed."
 
         assert pool_method in ('max', 'avg')
-        self.dropout = nn.Dropout(dropout, inplace=True)
+        self.dropout = nn.Dropout(dropout)
         self.pool_method = pool_method
         # activation function
         if isinstance(activation, str):
@@ -757,6 +801,7 @@ class CNNCharEmbedding(TokenEmbedding):
         :param words: [batch_size, max_len]
         :return: [batch_size, max_len, embed_size]
         """
+        words = self.drop_word(words)
         batch_size, max_len = words.size()
         chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
         word_lengths = self.word_lengths[words]  # batch_size x max_len
@@ -779,7 +824,7 @@ class CNNCharEmbedding(TokenEmbedding):
             conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
             chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
         chars = self.fc(chars)
-        return chars
+        return self.dropout(chars)
 
     @property
     def requires_grad(self):
@@ -826,6 +871,7 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param vocab: 词表
     :param embed_size: embedding的大小。默认值为50.
     :param char_emb_size: character的embedding的大小。默认值为50.
+    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
     :param dropout: 以多大概率drop
     :param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50.
     :param pool_method: 支持'max', 'avg'
     :param min_char_freq: character的最小出现次数。默认值为2.
     :param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50,
-                 pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
-        super(LSTMCharEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
+                 dropout:float=0.5, hidden_size=50, pool_method: str='max', activation='relu', min_char_freq: int=2,
+                 bidirectional=True):
+        super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         assert hidden_size % 2 == 0, "Only even kernel is allowed."
         assert pool_method in ('max', 'avg')
         self.pool_method = pool_method
-        self.dropout = nn.Dropout(dropout, inplace=True)
+        self.dropout = nn.Dropout(dropout)
         # activation function
         if isinstance(activation, str):
             if activation.lower() == 'relu':
@@ -890,6 +937,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         :param words: [batch_size, max_len]
         :return: [batch_size, max_len, embed_size]
         """
+        words = self.drop_word(words)
         batch_size, max_len = words.size()
         chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
         word_lengths = self.word_lengths[words]  # batch_size x max_len
@@ -914,7 +962,7 @@ class LSTMCharEmbedding(TokenEmbedding):
 
         chars = self.fc(chars)
 
-        return chars
+        return self.dropout(chars)
 
     @property
     def requires_grad(self):
@@ -953,9 +1001,12 @@ class StackEmbedding(TokenEmbedding):
 
     :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致
+    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedding会在相同的位置
+        被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。
+    :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
 
     """
-    def __init__(self, embeds: List[TokenEmbedding]):
+    def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
         vocabs = []
         for embed in embeds:
             if hasattr(embed, 'get_word_vocab'):
@@ -964,7 +1015,7 @@ class StackEmbedding(TokenEmbedding):
         for vocab in vocabs[1:]:
             assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
 
-        super(StackEmbedding, self).__init__(_vocab)
+        super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
         assert isinstance(embeds, list)
         for embed in embeds:
             assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
@@ -1016,7 +1067,9 @@ class StackEmbedding(TokenEmbedding):
         :return: 返回的shape和当前这个stack embedding中embedding的组成有关
         """
         outputs = []
+        words = self.drop_word(words)
         for embed in self.embeds:
             outputs.append(embed(words))
-        return torch.cat(outputs, dim=-1)
+        outputs = self.dropout(torch.cat(outputs, dim=-1))
+        return outputs
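A closing sketch that stacks a word-level and a character-level embedding with the stack-level word_dropout; the vocabulary is a toy one and the default 'en' pretrained vectors are fetched on first use::

    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("the quick brown fox".split())
    word_embed = StaticEmbedding(vocab, model_dir_or_name='en')
    char_embed = CNNCharEmbedding(vocab, embed_size=50)
    # word_dropout at the stack level replaces the same positions with unk in every sub-embedding
    embed = StackEmbedding([word_embed, char_embed], word_dropout=0.02, dropout=0.1)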