@@ -35,15 +35,15 @@ class Embedding(nn.Module):
     Embedding component. The vocabulary size is available as self.num_embeddings and the embedding dimension as self.embedding_dim."""
 
-    def __init__(self, init_embed, dropout=0.0, dropout_word=0, unk_index=None):
+    def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None):
         """
         :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: size of the Embedding (a tuple(int, int)
             whose first int is vocab_size and second int is embed_dim); a Tensor, Embedding or ndarray is used directly to
             initialize the Embedding; a TokenEmbedding object may also be passed in.
+        :param float word_dropout: probability of randomly replacing a word with unk_index, so that the unk token gets enough
+            training; this also has some regularizing effect on the network.
         :param float dropout: dropout applied to the output of the Embedding.
-        :param float dropout_word: probability of randomly replacing a word with the unk index, so that the unk token gets enough training.
-        :param int unk_index: the index a dropped word is replaced with; not needed when init_embed is a TokenEmbedding.
+        :param int unk_index: the index a dropped word is replaced with. fastNLP's Vocabulary uses unk_index=1 by default.
         """
         super(Embedding, self).__init__()
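
The docstring allows init_embed to be a (vocab_size, embed_dim) tuple, a tensor/ndarray, an nn.Embedding, or a TokenEmbedding. A hedged usage sketch of the tuple case under the new signature; the import path is an assumption based on this module's aliases and may differ between fastNLP versions:

```python
import torch
from fastNLP.modules import Embedding  # import path assumed from this module's aliases

# A (vocab_size, embed_dim) tuple gives a randomly initialized lookup table; word_dropout and dropout are the
# new regularization knobs, and unk_index is where dropped words are mapped (1 in fastNLP's Vocabulary).
embed = Embedding((10000, 100), word_dropout=0.01, dropout=0.5, unk_index=1)

x = torch.randint(0, 10000, (2, 7))  # batch_size x seq_len of word indices
print(embed(x).shape)                # torch.Size([2, 7, 100])
```
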
@@ -52,21 +52,21 @@ class Embedding(nn.Module):
         self.dropout = nn.Dropout(dropout)
         if not isinstance(self.embed, TokenEmbedding):
             self._embed_size = self.embed.weight.size(1)
-            if dropout_word>0 and not isinstance(unk_index, int):
+            if word_dropout>0 and not isinstance(unk_index, int):
                 raise ValueError("When drop word is set, you need to pass in the unk_index.")
         else:
             self._embed_size = self.embed.embed_size
             unk_index = self.embed.get_word_vocab().unknown_idx
         self.unk_index = unk_index
-        self.dropout_word = dropout_word
+        self.word_dropout = word_dropout
 
     def forward(self, x):
         """
         :param torch.LongTensor x: [batch, seq_len]
         :return: torch.Tensor : [batch, seq_len, embed_dim]
         """
-        if self.dropout_word>0 and self.training:
-            mask = torch.ones_like(x).float() * self.dropout_word
+        if self.word_dropout>0 and self.training:
+            mask = torch.ones_like(x).float() * self.word_dropout
             mask = torch.bernoulli(mask).byte()  # the larger the word dropout, the more positions are set to 1
             x = x.masked_fill(mask, self.unk_index)
         x = self.embed(x)
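
The forward pass above implements word-level dropout by sampling a Bernoulli mask over the index tensor and writing unk_index into the masked positions before the lookup. A self-contained sketch of just that masking step (the helper name and toy tensors are illustrative only):

```python
import torch

def word_dropout(x: torch.LongTensor, p: float, unk_index: int) -> torch.LongTensor:
    """Replace each token id with unk_index with probability p (only meaningful at training time)."""
    mask = torch.bernoulli(torch.full_like(x, p, dtype=torch.float)).bool()
    return x.masked_fill(mask, unk_index)

x = torch.tensor([[2, 5, 7, 3, 9]])         # batch_size x seq_len of word indices
print(word_dropout(x, p=0.3, unk_index=1))  # roughly 30% of positions become 1 (the unk index)
```
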
@@ -117,11 +117,38 @@ class Embedding(nn.Module):
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, vocab):
+    def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
         super(TokenEmbedding, self).__init__()
-        assert vocab.padding_idx is not None, "Your vocabulary must have padding."
+        assert vocab.padding is not None, "Vocabulary must have a padding entry."
         self._word_vocab = vocab
         self._word_pad_index = vocab.padding_idx
+        if word_dropout > 0:
+            assert vocab.unknown is not None, "Vocabulary must have an unknown entry when you want to drop a word."
+        self.word_dropout = word_dropout
+        self._word_unk_index = vocab.unknown_idx
+        self.dropout_layer = nn.Dropout(dropout)
+
+    def drop_word(self, words):
+        """
+        Randomly replace words with unknown_index according to the configured probability.
+
+        :param torch.LongTensor words: batch_size x max_len
+        :return:
+        """
+        if self.word_dropout > 0 and self.training:
+            mask = torch.ones_like(words).float() * self.word_dropout
+            mask = torch.bernoulli(mask).byte()  # the larger the word dropout, the more positions are set to 1
+            words = words.masked_fill(mask, self._word_unk_index)
+        return words
+
+    def dropout(self, words):
+        """
+        Apply dropout to the word representations after embedding.
+
+        :param torch.FloatTensor words: batch_size x max_len x embed_size
+        :return:
+        """
+        return self.dropout_layer(words)
 
     @property
     def requires_grad(self):
@@ -163,6 +190,9 @@ class TokenEmbedding(nn.Module):
     def size(self):
         return torch.Size([self.num_embedding, self._embed_size])
 
+    @abstractmethod
+    def forward(self, *input):
+        raise NotImplementedError
 
 class StaticEmbedding(TokenEmbedding):
     """
@@ -181,13 +211,15 @@ class StaticEmbedding(TokenEmbedding):
         `en-word2vec-300` : GoogleNews-vectors-negative300}. In the second case, the cache is checked for the model first and it is downloaded automatically if it is not there.
     :param bool requires_grad: whether gradients are needed. Defaults to True.
     :param callable init_method: how to initialize values that were not found; any of the torch.nn.init.* methods can be used, and it is called with a tensor object.
-    :param bool normalize: whether to normalize each vector so that its norm is 1.
     :param bool lower: whether to lowercase the words in vocab before matching them against the pretrained vocabulary. If your
         vocabulary contains uppercase words, or uppercase words need their own vector representation, set lower to False.
+    :param float word_dropout: probability of replacing a word with unk; this trains unk and also acts as regularization.
+    :param float dropout: probability of applying Dropout to the embedding output; 0.1 randomly zeroes 10% of the values.
+    :param bool normalize: whether to normalize each vector so that its norm is 1.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
-                 normalize=False, lower=False):
-        super(StaticEmbedding, self).__init__(vocab)
+                 lower=False, dropout=0, word_dropout=0, normalize=False):
+        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # resolve the cache_path
         if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
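
A hedged usage sketch of the reworked constructor. The import path and the behaviour of the 'en' alias are assumptions based on this file; constructing the embedding downloads or loads the pretrained vectors on first use:

```python
import torch
from fastNLP import Vocabulary
from fastNLP.modules import StaticEmbedding  # import path assumed from this module's aliases

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox jumps over the lazy dog".split())

# 'en' is the default pretrained alias in the signature above.
embed = StaticEmbedding(vocab, model_dir_or_name='en', lower=True,
                        word_dropout=0.01, dropout=0.3, normalize=False)

words = torch.LongTensor([[vocab.to_index(w) for w in "the quick brown fox".split()]])
print(embed(words).shape)  # torch.Size([1, 4, embed_dim])
```
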
@@ -362,12 +394,15 @@ class StaticEmbedding(TokenEmbedding):
         """
         if hasattr(self, 'words_to_words'):
             words = self.words_to_words[words]
-        return self.embedding(words)
+        words = self.drop_word(words)
+        words = self.embedding(words)
+        words = self.dropout(words)
+        return words
 
 
 class ContextualEmbedding(TokenEmbedding):
-    def __init__(self, vocab: Vocabulary):
-        super(ContextualEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, word_dropout: float=0.0, dropout: float=0.0):
+        super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
     def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True):
         """
@@ -473,12 +508,14 @@ class ElmoEmbedding(ContextualEmbedding):
         concatenated in this order. Defaults to '2'. 'mix' combines the representations of the different layers with learnable weights (trainability of the weights follows requires_grad;
         they are initialized to mean-pool the three layers, and ElmoEmbedding.set_mix_weights_requires_grad() can make only the mix weights learnable.)
     :param requires_grad: bool, whether this layer needs gradients. Defaults to False.
+    :param float word_dropout: probability of replacing a word with unk; this trains unk and also acts as regularization.
+    :param float dropout: probability of applying Dropout to the embedding output; 0.1 randomly zeroes 10% of the values.
     :param cache_word_reprs: whether to cache word representations; if True, an embedding is generated for every word at
         initialization, the character encoder is deleted, and the cached embeddings are used from then on. Defaults to False.
     """
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
-                 layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
-        super(ElmoEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', layers: str='2', requires_grad: bool=False,
+                 word_dropout=0.0, dropout=0.0, cache_word_reprs: bool=False):
+        super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # check whether the model exists and download it according to model_dir_or_name
         if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
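
The 'mix' option described above combines the three ELMo layers with learnable scalar weights (initialized to mean pooling). A hedged sketch of freezing the encoder while enabling only those weights, using set_mix_weights_requires_grad as named in the docstring; the import path is assumed:

```python
from fastNLP import Vocabulary
from fastNLP.modules import ElmoEmbedding  # import path assumed

vocab = Vocabulary()
vocab.add_word_lst("I like NLP".split())

# requires_grad=False keeps the encoder frozen; layers='mix' uses the learnable layer-mixing weights.
embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='mix',
                      requires_grad=False, word_dropout=0.0, dropout=0.5)
embed.set_mix_weights_requires_grad(True)  # only the mix weights (and gamma) receive gradients
```
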
@@ -545,11 +582,13 @@ class ElmoEmbedding(ContextualEmbedding):
         :param words: batch_size x max_len
         :return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
         """
+        words = self.drop_word(words)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
-            return outputs
+            return self.dropout(outputs)
         outputs = self.model(words)
-        return self._get_outputs(outputs)
+        outputs = self._get_outputs(outputs)
+        return self.dropout(outputs)
 
     def _delete_model_weights(self):
         for name in ['layers', 'model', 'layer_weights', 'gamma']:
@@ -595,13 +634,16 @@ class BertEmbedding(ContextualEmbedding):
     :param str layers: which layers go into the final representation; layer indices separated by ',', negative numbers index from the end.
     :param str pool_method: in BERT every word is represented by several word pieces; this controls how a word's representation
         is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
+    :param float word_dropout: probability of replacing a word with unk; this trains unk and also acts as regularization.
+    :param float dropout: probability of applying Dropout to the embedding output; 0.1 randomly zeroes 10% of the values.
     :param bool include_cls_sep: when BERT computes the sentence representation, [CLS] and [SEP] are added; this controls whether
         they are kept in the result, which makes the word embedding output two tokens longer than the input and may cause
         problems when using :class:`StackEmbedding`.
     :param bool requires_grad: whether gradients are needed.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
-                 pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
-        super(BertEmbedding, self).__init__(vocab)
+                 pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False,
+                 include_cls_sep: bool=False):
+        super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # check whether the model exists and download it according to model_dir_or_name
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
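
A hedged usage sketch of the new BertEmbedding signature; the import path is assumed, and 'en-base-uncased' is the default alias shown above:

```python
import torch
from fastNLP import Vocabulary
from fastNLP.modules import BertEmbedding  # import path assumed

vocab = Vocabulary()
vocab.add_word_lst("this is a test".split())

# '-1,-2' concatenates the last two encoder layers; pool_method picks how word pieces collapse to one word vector.
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1,-2',
                      pool_method='first', word_dropout=0.01, dropout=0.1, include_cls_sep=False)

words = torch.LongTensor([[vocab.to_index(w) for w in "this is a test".split()]])
print(embed(words).shape)  # batch_size x max_len x (768 * number of requested layers)
```
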
@@ -632,13 +674,14 @@ class BertEmbedding(ContextualEmbedding):
         :param torch.LongTensor words: [batch_size, max_len]
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
        """
+        words = self.drop_word(words)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
-            return outputs
+            return self.dropout(outputs)
         outputs = self.model(words)
         outputs = torch.cat([*outputs], dim=-1)
-        return outputs
+        return self.dropout(outputs)
 
     @property
     def requires_grad(self):
@@ -680,8 +723,8 @@ class CNNCharEmbedding(TokenEmbedding):
     """
     Alias: :class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`
 
-    Uses a CNN to generate character embeddings. The CNN pipeline is embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool
-    -> fc. The outputs of filters with different kernel sizes are concatenated.
+    Uses a CNN to generate character embeddings. The CNN pipeline is embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
+    The outputs of filters with different kernel sizes are concatenated.
 
     Example::
@@ -691,23 +734,24 @@ class CNNCharEmbedding(TokenEmbedding):
     :param vocab: the vocabulary.
     :param embed_size: size of this word embedding. Defaults to 50.
     :param char_emb_size: size of the character embedding; characters are derived from vocab. Defaults to 50.
-    :param dropout: dropout probability.
+    :param float word_dropout: probability of replacing a word with unk; this trains unk and also acts as regularization.
+    :param float dropout: dropout probability.
     :param filter_nums: number of filters; must have the same length as kernels. Defaults to [40, 30, 20].
     :param kernel_sizes: kernel sizes. Defaults to [5, 3, 1].
     :param pool_method: pooling method used when the character representations are merged into one representation; supports 'avg', 'max'.
     :param activation: activation used after the CNN; supports 'relu', 'sigmoid', 'tanh' or a custom callable.
     :param min_char_freq: minimum number of occurrences of a character. Defaults to 2.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout: float=0.5,
-                 filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
-                 activation='relu', min_char_freq: int=2):
-        super(CNNCharEmbedding, self).__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout: float=0,
+                 dropout: float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
+                 pool_method: str='max', activation='relu', min_char_freq: int=2):
+        super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         for kernel in kernel_sizes:
             assert kernel % 2 == 1, "Only odd kernel is allowed."
 
         assert pool_method in ('max', 'avg')
-        self.dropout = nn.Dropout(dropout, inplace=True)
+        self.dropout = nn.Dropout(dropout)
         self.pool_method = pool_method
         # activation function
         if isinstance(activation, str):
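
A hedged usage sketch of the new signature; the character vocabulary is built internally from `vocab`, and the import path is assumed. The drop_word call added to forward() in the next hunk is what word_dropout feeds into:

```python
import torch
from fastNLP import Vocabulary
from fastNLP.modules import CNNCharEmbedding  # import path assumed

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox".split())

char_embed = CNNCharEmbedding(vocab, embed_size=50, char_emb_size=30, word_dropout=0.01, dropout=0.5,
                              filter_nums=[40, 30, 20], kernel_sizes=[5, 3, 1], pool_method='max')

words = torch.LongTensor([[vocab.to_index(w) for w in "the quick brown fox".split()]])
print(char_embed(words).shape)  # torch.Size([1, 4, 50])
```
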
@@ -757,6 +801,7 @@ class CNNCharEmbedding(TokenEmbedding):
         :param words: [batch_size, max_len]
         :return: [batch_size, max_len, embed_size]
         """
+        words = self.drop_word(words)
         batch_size, max_len = words.size()
         chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
         word_lengths = self.word_lengths[words]  # batch_size x max_len
@@ -779,7 +824,7 @@ class CNNCharEmbedding(TokenEmbedding):
             conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
             chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
         chars = self.fc(chars)
-        return chars
+        return self.dropout(chars)
 
     @property
     def requires_grad(self):
@@ -826,6 +871,7 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param vocab: the vocabulary.
     :param embed_size: size of the embedding. Defaults to 50.
     :param char_emb_size: size of the character embedding. Defaults to 50.
+    :param float word_dropout: probability of replacing a word with unk; this trains unk and also acts as regularization.
     :param dropout: dropout probability.
     :param hidden_size: hidden size of the LSTM; if bidirectional, the hidden size is halved per direction. Defaults to 50.
     :param pool_method: supports 'max', 'avg'.
@@ -833,15 +879,16 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param min_char_freq: minimum number of occurrences of a character. Defaults to 2.
     :param bidirectional: whether to encode with a bidirectional LSTM. Defaults to True.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout: float=0.5, hidden_size=50,
-                 pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout: float=0,
+                 dropout: float=0.5, hidden_size=50, pool_method: str='max', activation='relu', min_char_freq: int=2,
+                 bidirectional=True):
         super(LSTMCharEmbedding, self).__init__(vocab)
 
         assert hidden_size % 2 == 0, "Only even hidden_size is allowed."
 
         assert pool_method in ('max', 'avg')
         self.pool_method = pool_method
-        self.dropout = nn.Dropout(dropout, inplace=True)
+        self.dropout = nn.Dropout(dropout)
         # activation function
         if isinstance(activation, str):
             if activation.lower() == 'relu':
@@ -890,6 +937,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         :param words: [batch_size, max_len]
         :return: [batch_size, max_len, embed_size]
         """
+        words = self.drop_word(words)
         batch_size, max_len = words.size()
         chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
         word_lengths = self.word_lengths[words]  # batch_size x max_len
@@ -914,7 +962,7 @@ class LSTMCharEmbedding(TokenEmbedding):
 
         chars = self.fc(chars)
 
-        return chars
+        return self.dropout(chars)
 
     @property
     def requires_grad(self):
@@ -953,9 +1001,12 @@ class StackEmbedding(TokenEmbedding):
 
     :param embeds: a list of TokenEmbedding objects; every TokenEmbedding must use the same vocabulary.
+    :param float word_dropout: probability of replacing a word with unk; this trains unk and also acts as regularization. The
+        different embeddings are set to unknown at the same positions. If dropout is set here, the member embeddings should
+        not set dropout again.
+    :param float dropout: probability of applying Dropout to the embedding output; 0.1 randomly zeroes 10% of the values.
     """
-    def __init__(self, embeds: List[TokenEmbedding]):
+    def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
         vocabs = []
         for embed in embeds:
             if hasattr(embed, 'get_word_vocab'):
@@ -964,7 +1015,7 @@ class StackEmbedding(TokenEmbedding):
         for vocab in vocabs[1:]:
             assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
 
-        super(StackEmbedding, self).__init__(_vocab)
+        super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
         assert isinstance(embeds, list)
         for embed in embeds:
             assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
@@ -1016,7 +1067,9 @@ class StackEmbedding(TokenEmbedding):
         :return: the shape depends on which embeddings make up this stack embedding
         """
         outputs = []
+        words = self.drop_word(words)
         for embed in self.embeds:
             outputs.append(embed(words))
-        return torch.cat(outputs, dim=-1)
+        outputs = self.dropout(torch.cat(outputs, dim=-1))
+        return outputs
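
A hedged sketch of stacking a word-level and a character-level embedding over the same vocabulary under the new signature; import paths are assumed, and the StaticEmbedding line downloads pretrained vectors on first use:

```python
import torch
from fastNLP import Vocabulary
from fastNLP.modules import StaticEmbedding, CNNCharEmbedding, StackEmbedding  # import paths assumed

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox".split())

word_embed = StaticEmbedding(vocab, model_dir_or_name='en')
char_embed = CNNCharEmbedding(vocab, embed_size=50)
# Per the docstring above, word_dropout set on the stack replaces the same positions with unk for every member,
# and dropout is applied once to the concatenated output.
stack = StackEmbedding([word_embed, char_embed], word_dropout=0.01, dropout=0.3)

words = torch.LongTensor([[vocab.to_index(w) for w in "the quick brown fox".split()]])
print(stack(words).shape)  # last dimension = word_embed.embed_size + char_embed.embed_size
```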