
1. Add a learning rate WarmupCallback; 2. Add a model-saving callback; 3. Add handling of BIO-style tags in utils; 4. Add word_dropout and dropout options to embedding

tags/v0.4.10
yh 6 years ago
commit 43fac849f9
4 changed files with 323 additions and 44 deletions
1. +26 -1 fastNLP/core/batch.py
2. +128 -0 fastNLP/core/callback.py
3. +75 -2 fastNLP/core/utils.py
4. +94 -41 fastNLP/modules/encoder/embedding.py

+26 -1 fastNLP/core/batch.py

@@ -3,7 +3,6 @@ The batch module implements the Batch classes required by fastNLP.

"""
__all__ = [
"BatchIter",
"DataSetIter",
"TorchLoaderIter",
]
@@ -50,6 +49,7 @@ class DataSetGetter:
return len(self.dataset)

def collate_fn(self, batch: list):
# TODO Support defining collate_fn in DataSet, since different fields may sometimes need to be combined, e.g. in the BERT scenario
batch_x = {n:[] for n in self.inputs.keys()}
batch_y = {n:[] for n in self.targets.keys()}
indices = []
@@ -136,6 +136,31 @@ class BatchIter:


class DataSetIter(BatchIter):
"""
别名::class:`fastNLP.DataSetIter` :class:`fastNLP.core.batch.DataSetIter`

DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出,
组成 `x` 和 `y`::

batch = DataSetIter(data_set, batch_size=16, sampler=SequentialSampler())
num_batch = len(batch)
for batch_x, batch_y in batch:
# do stuff ...

:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
:param int batch_size: 取出的batch大小
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`.

Default: ``None``
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`.

Default: ``False``
:param int num_workers: 使用多少个进程来预处理数据
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个
:param timeout:
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
"""
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):


+128 -0 fastNLP/core/callback.py

@@ -66,6 +66,8 @@ import os

import torch
from copy import deepcopy
import sys
from .utils import _save_model

try:
from tensorboardX import SummaryWriter
@@ -737,6 +739,132 @@ class TensorboardCallback(Callback):
del self._summary_writer


class WarmupCallback(Callback):
"""
按一定的周期调节Learning rate的大小。

:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float,
如0.1, 则前10%的step是按照schedule策略调整learning rate。
:param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后
warmup的step下降到0; constant前warmup的step上升到指定learning rate,后面的step保持learning rate.
"""
def __init__(self, warmup=0.1, schedule='constant'):
super().__init__()
self.warmup = max(warmup, 0.)

self.initial_lrs = []  # the initial learning rate of each param_group
if schedule == 'constant':
self.get_lr = self._get_constant_lr
elif schedule == 'linear':
self.get_lr = self._get_linear_lr
else:
raise RuntimeError("Only support 'linear', 'constant'.")

def _get_constant_lr(self, progress):
if progress<self.warmup:
return progress/self.warmup
return 1

def _get_linear_lr(self, progress):
if progress<self.warmup:
return progress/self.warmup
return max((progress - 1.) / (self.warmup - 1.), 0.)

def on_train_begin(self):
self.t_steps = (len(self.trainer.train_data) // (self.batch_size*self.update_every) +
int(len(self.trainer.train_data) % (self.batch_size*self.update_every)!= 0)) * self.n_epochs
if self.warmup>1:
self.warmup = self.warmup/self.t_steps
self.t_steps = max(2, self.t_steps) # must be at least 2
# record the initial learning rate of each param_group
for group in self.optimizer.param_groups:
self.initial_lrs.append(group['lr'])

def on_backward_end(self):
if self.step%self.update_every==0:
progress = (self.step/self.update_every)/self.t_steps
for lr, group in zip(self.initial_lrs, self.optimizer.param_groups):
group['lr'] = lr * self.get_lr(progress)

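As a usage sketch (not part of this commit; the data and model objects are placeholders, assuming the usual fastNLP Trainer(train_data, model, ..., callbacks=[...]) interface), the callback would be attached like this::

from fastNLP import Trainer
from fastNLP.core.callback import WarmupCallback

# warm the learning rate up over the first 10% of the steps, then keep it constant
warmup_cb = WarmupCallback(warmup=0.1, schedule='constant')
# `model`, `train_data` and `dev_data` are placeholders for your own objects
trainer = Trainer(train_data, model, dev_data=dev_data, callbacks=[warmup_cb])
trainer.train()

With schedule='linear' and warmup=0.1, the multiplier returned by _get_linear_lr rises from 0 to 1 over the first 10% of the steps and then decays linearly back to 0 by the end of training.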

class SaveModelCallback(Callback):
"""
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型
-save_dir
-2019-07-03-15-06-36
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt
-2019-07-03-15-10-00
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型
:param int top: 保存dev表现top多少模型。-1为保存所有模型。
:param bool only_param: 是否只保存模型d饿权重。
:param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}.
"""
def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False):
super().__init__()

if not os.path.isdir(save_dir):
raise NotADirectoryError("{} is not a directory.".format(save_dir))
self.save_dir = save_dir
if top < 0:
self.top = sys.maxsize
else:
self.top = top
self._ordered_save_models = []  # List[Tuple]; Tuple[0] is the metric value, Tuple[1] is the model name. Metrics improve along the list, so deletions happen at the head

self.only_param = only_param
self.save_on_exception = save_on_exception

def on_train_begin(self):
self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)

def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
metric_value = list(eval_result.values())[0][metric_key]
self._save_this_model(metric_value)

def _insert_into_ordered_save_models(self, pair):
# pair: (metric_value, model_name)
# returns the saved model pair and the deleted model pair; the first element of each pair is the metric value, the second the model name
index = -1
for _pair in self._ordered_save_models:
if _pair[0]>=pair[0] and self.trainer.increase_better:
break
if not self.trainer.increase_better and _pair[0]<=pair[0]:
break
index += 1
save_pair = None
if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1):
save_pair = pair
self._ordered_save_models.insert(index+1, pair)
delete_pair = None
if len(self._ordered_save_models)>self.top:
delete_pair = self._ordered_save_models.pop(0)
return save_pair, delete_pair

def _save_this_model(self, metric_value):
name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
if save_pair:
try:
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
except Exception as e:
print(f"The following exception:{e} happens when save model to {self.save_dir}.")
if delete_pair:
try:
delete_model_path = os.path.join(self.save_dir, delete_pair[1])
if os.path.exists(delete_model_path):
os.remove(delete_model_path)
except Exception as e:
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")

def on_exception(self, exception):
if self.save_on_exception:
name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)

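A hedged usage sketch (the directory name and Trainer wiring are illustrative only; save_dir must already exist, since the constructor checks it with os.path.isdir)::

import os
from fastNLP.core.callback import SaveModelCallback

save_dir = 'checkpoints'              # placeholder path
os.makedirs(save_dir, exist_ok=True)  # the callback requires an existing directory
# keep the 3 best models on the dev set and also dump a model if an exception occurs
save_cb = SaveModelCallback(save_dir, top=3, only_param=True, save_on_exception=True)
# pass it to the Trainer together with other callbacks, e.g. callbacks=[warmup_cb, save_cb]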

class CallbackException(BaseException):
"""
When training needs to be interrupted from a callback, raise a CallbackException and catch it in on_exception.


+75 -2 fastNLP/core/utils.py

@@ -16,6 +16,7 @@ from collections import Counter, namedtuple
import numpy as np
import torch
import torch.nn as nn
from typing import List

_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])
@@ -162,6 +163,30 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
return wrapper_

def _save_model(model, model_name, save_dir, only_param=False):
""" 存储不含有显卡信息的state_dict或model
:param model:
:param model_name:
:param save_dir: 保存的directory
:param only_param:
:return:
"""
model_path = os.path.join(save_dir, model_name)
if not os.path.isdir(save_dir):
os.makedirs(save_dir, exist_ok=True)
if isinstance(model, nn.DataParallel):
model = model.module
if only_param:
state_dict = model.state_dict()
for key in state_dict:
state_dict[key] = state_dict[key].cpu()
torch.save(state_dict, model_path)
else:
_model_device = _get_model_device(model)
model.cpu()
torch.save(model, model_path)
model.to(_model_device)

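For illustration only, a minimal sketch of calling this helper directly (the model below is a stand-in; in the SaveModelCallback above it is invoked with the Trainer's model)::

import torch.nn as nn

net = nn.Linear(10, 2)  # placeholder model
# writes a CPU-only state_dict to ./checkpoints/demo.pt, creating the directory if needed
_save_model(net, model_name='demo.pt', save_dir='./checkpoints', only_param=True)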

# def save_pickle(obj, pickle_path, file_name):
# """Save an object into a pickle file.
@@ -277,7 +302,6 @@ def _move_model_to_device(model, device):
return model



def _get_model_device(model):
"""
Given an nn.Module model, return the device it resides on.
@@ -285,7 +309,7 @@ def _get_model_device(model):
:param model: nn.Module
:return: torch.device or None. If None is returned, the model has no parameters.
"""
# TODO This function carries some risk, since some parameters of a model may not be on the GPU, e.g. BertEmbedding
# TODO This function carries some risk, since some parameters of a model may not be on the GPU, e.g. BertEmbedding, or they may be spread across GPUs
assert isinstance(model, nn.Module)
parameters = list(model.parameters())
@@ -712,3 +736,52 @@ class _pseudo_tqdm:
def __exit__(self, exc_type, exc_val, exc_tb):
del self

def iob2(tags:List[str])->List[str]:
"""
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format

:param tags: 需要转换的tags, 需要为大写的BIO标签。
"""
for i, tag in enumerate(tags):
if tag == "O":
continue
split = tag.split("-")
if len(split) != 2 or split[0] not in ["I", "B"]:
raise TypeError("The encoding schema is not a valid IOB type.")
if split[0] == "B":
continue
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2
tags[i] = "B" + tag[1:]
elif tags[i - 1][1:] == tag[1:]:
continue
else: # conversion IOB1 to IOB2
tags[i] = "B" + tag[1:]
return tags

def iob2bioes(tags:List[str])->List[str]:
"""
将iob的tag转换为bioes编码
:param tags: List[str]. 编码需要是大写的。
:return:
"""
new_tags = []
for i, tag in enumerate(tags):
if tag == 'O':
new_tags.append(tag)
else:
split = tag.split('-')[0]
if split == 'B':
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
else:
raise TypeError("Invalid IOB format.")
return new_tags

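A small worked example of the two helpers (the tag sequence is made up)::

tags = ['I-PER', 'I-PER', 'O', 'I-LOC']   # IOB1: a chunk may start with I-
tags = iob2(tags)                          # -> ['B-PER', 'I-PER', 'O', 'B-LOC']
tags = iob2bioes(tags)                     # -> ['B-PER', 'E-PER', 'O', 'S-LOC']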
+94 -41 fastNLP/modules/encoder/embedding.py

@@ -35,15 +35,15 @@ class Embedding(nn.Module):

An Embedding component. The vocabulary size is available via self.num_embeddings and the embedding dimension via self.embedding_dim"""
def __init__(self, init_embed, dropout=0.0, dropout_word=0, unk_index=None):
def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None):
"""

:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: size of the Embedding (when a tuple(int, int) is passed,
the first int is vocab_size and the second int is embed_dim); if a Tensor, Embedding or ndarray is passed, it is used directly to
initialize the Embedding; a TokenEmbedding object may also be passed
:param float word_dropout: probability of randomly replacing a word with unk_index, so that the unk token gets enough training
and the network is regularized to some extent.
:param float dropout: dropout applied to the output of the Embedding.
:param float dropout_word: probability of randomly replacing a word with the unk index, so that the unk token gets enough training
:param int unk_index: index to substitute when a word is dropped; not needed when init_embed is a TokenEmbedding.
:param int unk_index: index to substitute when a word is dropped. The unk_index of fastNLP's Vocabulary defaults to 1.
"""
super(Embedding, self).__init__()

@@ -52,21 +52,21 @@ class Embedding(nn.Module):
self.dropout = nn.Dropout(dropout)
if not isinstance(self.embed, TokenEmbedding):
self._embed_size = self.embed.weight.size(1)
if dropout_word>0 and not isinstance(unk_index, int):
if word_dropout>0 and not isinstance(unk_index, int):
raise ValueError("When drop word is set, you need to pass in the unk_index.")
else:
self._embed_size = self.embed.embed_size
unk_index = self.embed.get_word_vocab().unknown_idx
self.unk_index = unk_index
self.dropout_word = dropout_word
self.word_dropout = word_dropout

def forward(self, x):
"""
:param torch.LongTensor x: [batch, seq_len]
:return: torch.Tensor : [batch, seq_len, embed_dim]
"""
if self.dropout_word>0 and self.training:
mask = torch.ones_like(x).float() * self.dropout_word
if self.word_dropout>0 and self.training:
mask = torch.ones_like(x).float() * self.word_dropout
mask = torch.bernoulli(mask).byte()  # the larger word_dropout is, the more positions are set to 1
x = x.masked_fill(mask, self.unk_index)
x = self.embed(x)
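A hedged sketch of how the new word_dropout option behaves on the generic Embedding (the vocabulary size and unk_index below are made-up values; unk_index=1 mirrors the Vocabulary default mentioned above)::

import torch
from fastNLP.modules.encoder.embedding import Embedding

embed = Embedding((100, 50), word_dropout=0.1, dropout=0.1, unk_index=1)
embed.train()                        # word dropout is only applied in training mode
x = torch.randint(2, 100, (4, 7))    # batch_size=4, seq_len=7
out = embed(x)                       # ~10% of the ids are replaced by unk before lookup
print(out.size())                    # torch.Size([4, 7, 50])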
@@ -117,11 +117,38 @@ class Embedding(nn.Module):


class TokenEmbedding(nn.Module):
def __init__(self, vocab):
def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
super(TokenEmbedding, self).__init__()
assert vocab.padding_idx is not None, "You vocabulary must have padding."
assert vocab.padding is not None, "Vocabulary must have a padding entry."
self._word_vocab = vocab
self._word_pad_index = vocab.padding_idx
if word_dropout>0:
assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
self.word_dropout = word_dropout
self._word_unk_index = vocab.unknown_idx
self.dropout_layer = nn.Dropout(dropout)

def drop_word(self, words):
"""
Randomly replace words with unknown_index according to the configured probability.

:param torch.LongTensor words: batch_size x max_len
:return:
"""
if self.word_dropout > 0 and self.training:
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.bernoulli(mask).byte()  # the larger word_dropout is, the more positions are set to 1
words = words.masked_fill(mask, self._word_unk_index)
return words

def dropout(self, words):
"""
Apply dropout to the word representations produced by the embedding.

:param torch.FloatTensor words: batch_size x max_len x embed_size
:return:
"""
return self.dropout_layer(words)

@property
def requires_grad(self):
@@ -163,6 +190,9 @@ class TokenEmbedding(nn.Module):
def size(self):
return torch.Size([self.num_embedding, self._embed_size])

@abstractmethod
def forward(self, *input):
raise NotImplementedError

class StaticEmbedding(TokenEmbedding):
"""
@@ -181,13 +211,15 @@ class StaticEmbedding(TokenEmbedding):
`en-word2vec-300` : GoogleNews-vectors-negative300}. In the second case, the cache is checked for the model first and it is downloaded automatically if not found.
:param bool requires_grad: whether gradients are required. Defaults to True
:param callable init_method: how to initialize values that were not found. Any of the torch.nn.init.* methods can be used; the method is called with a tensor argument.
:param bool normalize: whether to normalize the vectors so that each vector has norm 1.
:param bool lower: whether to lowercase the words in the vocab before matching against the pretrained vocabulary. If your vocabulary contains
uppercase words, or uppercase words need their own vector representations, set lower to False.
:param float word_dropout: probability of replacing a word with unk; this both trains unk and acts as a form of regularization.
:param float dropout: probability of applying Dropout to the embedding output. 0.1 means 10% of the values are randomly set to 0.
:param bool normalize: whether to normalize the vectors so that each vector has norm 1.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
normalize=False, lower=False):
super(StaticEmbedding, self).__init__(vocab)
lower=False, dropout=0, word_dropout=0, normalize=False):
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

# resolve the cache_path
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
@@ -362,12 +394,15 @@ class StaticEmbedding(TokenEmbedding):
"""
if hasattr(self, 'words_to_words'):
words = self.words_to_words[words]
return self.embedding(words)
words = self.drop_word(words)
words = self.embedding(words)
words = self.dropout(words)
return words
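A hedged sketch of the new options on StaticEmbedding (the vocabulary content is a placeholder, and the 'en' pretrained vectors are downloaded on first use)::

from fastNLP import Vocabulary
from fastNLP.modules.encoder.embedding import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is a demo sentence".split())
vocab.build_vocab()
# 10% of the input words are replaced by unk and 10% dropout is applied to the output vectors
embed = StaticEmbedding(vocab, model_dir_or_name='en', word_dropout=0.1, dropout=0.1)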


class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary):
super(ContextualEmbedding, self).__init__(vocab)
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):
super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True):
"""
@@ -473,12 +508,14 @@ class ElmoEmbedding(ContextualEmbedding):
concatenated in this order. Defaults to '2'. 'mix' combines the representations of the different layers with learnable weights (whether the
weights are trainable follows requires_grad; they are initialized so as to mean-pool the three layers. ElmoEmbedding.set_mix_weights_requires_grad() can be used to make only the mix weights trainable.)
:param requires_grad: bool, whether this layer requires gradients. Defaults to False.
:param float word_dropout: probability of replacing a word with unk; this both trains unk and acts as a form of regularization.
:param float dropout: probability of applying Dropout to the embedding output. 0.1 means 10% of the values are randomly set to 0.
:param cache_word_reprs: whether to cache the word representations; if True, an embedding is generated for every word at initialization
and the character encoder is deleted, after which the cached embeddings are used directly. Defaults to False.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
super(ElmoEmbedding, self).__init__(vocab)
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', layers: str='2', requires_grad: bool=False,
word_dropout=0.0, dropout=0.0, cache_word_reprs: bool=False):
super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

# check whether model_dir_or_name exists locally and download it if necessary
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
@@ -545,11 +582,13 @@ class ElmoEmbedding(ContextualEmbedding):
:param words: batch_size x max_len
:return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
"""
words = self.drop_word(words)
outputs = self._get_sent_reprs(words)
if outputs is not None:
return outputs
return self.dropout(outputs)
outputs = self.model(words)
return self._get_outputs(outputs)
outputs = self._get_outputs(outputs)
return self.dropout(outputs)

def _delete_model_weights(self):
for name in ['layers', 'model', 'layer_weights', 'gamma']:
@@ -595,13 +634,16 @@ class BertEmbedding(ContextualEmbedding):
:param str layers: layers to include in the final representation. Layer indices are separated by ','; negative indices count from the last layer
:param str pool_method: in BERT each word is represented by several word pieces; this controls how a word's representation is computed
from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
:param float word_dropout: probability of replacing a word with unk; this both trains unk and acts as a form of regularization.
:param float dropout: probability of applying Dropout to the embedding output. 0.1 means 10% of the values are randomly set to 0.
:param bool include_cls_sep: when BERT computes the sentence representation, [CLS] and [SEP] are added; whether to keep these two tokens
in the result. Doing so makes the word embedding output two tokens longer than the input, which may cause problems with :class::StackEmbedding.
:param bool requires_grad: whether gradients are required.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
super(BertEmbedding, self).__init__(vocab)
pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False,
include_cls_sep: bool=False):
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

# check whether model_dir_or_name exists locally and download it if necessary
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -632,13 +674,14 @@ class BertEmbedding(ContextualEmbedding):
:param torch.LongTensor words: [batch_size, max_len]
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
words = self.drop_word(words)
outputs = self._get_sent_reprs(words)
if outputs is not None:
return outputs
return self.dropout(outputs)
outputs = self.model(words)
outputs = torch.cat([*outputs], dim=-1)

return outputs
return self.dropout(outputs)

@property
def requires_grad(self):
@@ -680,8 +723,8 @@ class CNNCharEmbedding(TokenEmbedding):
"""
Alias: :class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`

Generate character embeddings with a CNN. The CNN pipeline is embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool
-> fc. The outputs of filters with different kernel sizes are concatenated.
Generate character embeddings with a CNN. The CNN pipeline is embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
The outputs of filters with different kernel sizes are concatenated.

Example::

@@ -691,23 +734,24 @@ class CNNCharEmbedding(TokenEmbedding):
:param vocab: the vocabulary
:param embed_size: size of this word embedding. Defaults to 50.
:param char_emb_size: size of the character embedding. The characters are derived from vocab. Defaults to 50.
:param dropout: dropout probability
:param float word_dropout: probability of replacing a word with unk; this both trains unk and acts as a form of regularization.
:param float dropout: dropout probability
:param filter_nums: number of filters. Must have the same length as kernels. Defaults to [40, 30, 20].
:param kernel_sizes: kernel sizes. Defaults to [5, 3, 1].
:param pool_method: pooling method used when combining the character representations into a single representation; supports 'avg' and 'max'.
:param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh' or a custom callable.
:param min_char_freq: minimum character frequency. Defaults to 2.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5,
filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
activation='relu', min_char_freq: int=2):
super(CNNCharEmbedding, self).__init__(vocab)
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
pool_method: str='max', activation='relu', min_char_freq: int=2):
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

for kernel in kernel_sizes:
assert kernel % 2 == 1, "Only odd kernel is allowed."

assert pool_method in ('max', 'avg')
self.dropout = nn.Dropout(dropout, inplace=True)
self.dropout = nn.Dropout(dropout)
self.pool_method = pool_method
# activation function
if isinstance(activation, str):
@@ -757,6 +801,7 @@ class CNNCharEmbedding(TokenEmbedding):
:param words: [batch_size, max_len]
:return: [batch_size, max_len, embed_size]
"""
words = self.drop_word(words)
batch_size, max_len = words.size()
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
word_lengths = self.word_lengths[words] # batch_size x max_len
@@ -779,7 +824,7 @@ class CNNCharEmbedding(TokenEmbedding):
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
chars = self.fc(chars)
return chars
return self.dropout(chars)

@property
def requires_grad(self):
@@ -826,6 +871,7 @@ class LSTMCharEmbedding(TokenEmbedding):
:param vocab: the vocabulary
:param embed_size: size of the embedding. Defaults to 50.
:param char_emb_size: size of the character embedding. Defaults to 50.
:param float word_dropout: probability of replacing a word with unk; this both trains unk and acts as a form of regularization.
:param dropout: dropout probability
:param hidden_size: size of the LSTM hidden state; if bidirectional, the hidden size is halved. Defaults to 50.
:param pool_method: supports 'max', 'avg'
@@ -833,15 +879,16 @@ class LSTMCharEmbedding(TokenEmbedding):
:param min_char_freq: minimum character frequency. Defaults to 2.
:param bidirectional: whether to encode with a bidirectional LSTM. Defaults to True.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50,
pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
dropout:float=0.5, hidden_size=50, pool_method: str='max', activation='relu', min_char_freq: int=2,
bidirectional=True):
super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

assert hidden_size % 2 == 0, "Only even hidden_size is allowed."

assert pool_method in ('max', 'avg')
self.pool_method = pool_method
self.dropout = nn.Dropout(dropout, inplace=True)
self.dropout = nn.Dropout(dropout)
# activation function
if isinstance(activation, str):
if activation.lower() == 'relu':
@@ -890,6 +937,7 @@ class LSTMCharEmbedding(TokenEmbedding):
:param words: [batch_size, max_len]
:return: [batch_size, max_len, embed_size]
"""
words = self.drop_word(words)
batch_size, max_len = words.size()
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
word_lengths = self.word_lengths[words] # batch_size x max_len
@@ -914,7 +962,7 @@ class LSTMCharEmbedding(TokenEmbedding):

chars = self.fc(chars)

return chars
return self.dropout(chars)

@property
def requires_grad(self):
@@ -953,9 +1001,12 @@ class StackEmbedding(TokenEmbedding):


:param embeds: a list of TokenEmbedding objects; every TokenEmbedding must use the same word vocabulary
:param float word_dropout: probability of replacing a word with unk; this both trains unk and acts as a form of regularization. The different
embeddings are set to unknown at the same positions. If dropout is set here, the component embeddings should not set dropout again.
:param float dropout: probability of applying Dropout to the embedding output. 0.1 means 10% of the values are randomly set to 0.

"""
def __init__(self, embeds: List[TokenEmbedding]):
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
vocabs = []
for embed in embeds:
if hasattr(embed, 'get_word_vocab'):
@@ -964,7 +1015,7 @@ class StackEmbedding(TokenEmbedding):
for vocab in vocabs[1:]:
assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."

super(StackEmbedding, self).__init__(_vocab)
super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
assert isinstance(embeds, list)
for embed in embeds:
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
@@ -1016,7 +1067,9 @@ class StackEmbedding(TokenEmbedding):
:return: the returned shape depends on which embeddings make up this stack embedding
"""
outputs = []
words = self.drop_word(words)
for embed in self.embeds:
outputs.append(embed(words))
return torch.cat(outputs, dim=-1)
outputs = self.dropout(torch.cat(outputs, dim=-1))
return outputs
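A hedged sketch of stacking embeddings with shared word dropout (`vocab` is assumed to be a built Vocabulary shared by both components, as in the StaticEmbedding sketch above)::

from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding

word_embed = StaticEmbedding(vocab, model_dir_or_name='en')
char_embed = CNNCharEmbedding(vocab, embed_size=30)
# word_dropout here replaces the same positions with unk for both component embeddings
stack_embed = StackEmbedding([word_embed, char_embed], word_dropout=0.05, dropout=0.3)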

