Browse Source

update documents in embedding.py

tags/v0.4.10
xuyige 6 years ago
parent
commit
a366f156ac
2 changed files with 77 additions and 48 deletions
  1. +1
    -0
      fastNLP/io/embed_loader.py
  2. +76
    -48
      fastNLP/modules/encoder/embedding.py

+ 1
- 0
fastNLP/io/embed_loader.py View File

@@ -26,6 +26,7 @@ class EmbeddingOption(Option):
error=error
)


class EmbedLoader(BaseLoader):
"""
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`


+ 76
- 48
fastNLP/modules/encoder/embedding.py View File

@@ -131,17 +131,23 @@ class TokenEmbedding(nn.Module):


class StaticEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=False):
"""
给定embedding的名称,根据vocab从embedding中抽取相应的数据。该Embedding可以就按照正常的embedding使用了
"""
别名::class:`fastNLP.modules.StaticEmbedding` :class:`fastNLP.modules.encoder.embedding.StaticEmbedding`

Example::
StaticEmbedding组件. 给定embedding的名称,根据vocab从embedding中抽取相应的数据。该Embedding可以就按照正常的embedding使用了

Example::

:param vocab: Vocabulary. 若该项为None则会读取所有的embedding。
:param model_dir_or_name: 资源所在位置,可传入简写embedding名称,embedding对应资源可参考xxx
:param requires_grad: 是否需要gradient
"""

:param vocab: Vocabulary. 若该项为None则会读取所有的embedding。
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding
的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d,
`en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
:param requires_grad: 是否需要gradient

"""

def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=False):
super(StaticEmbedding, self).__init__(vocab)

# 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server,
@@ -185,11 +191,11 @@ class StaticEmbedding(TokenEmbedding):
return self.embedding(words)


class DynamicEmbedding(TokenEmbedding):
class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary):
super(DynamicEmbedding, self).__init__(vocab)
super(ContextualEmbedding, self).__init__(vocab)

def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights:bool=True):
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True):
"""
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。

@@ -280,9 +286,12 @@ class DynamicEmbedding(TokenEmbedding):
del self.sent_embeds


class ElmoEmbedding(DynamicEmbedding):
class ElmoEmbedding(ContextualEmbedding):
"""
使用ELMO的embedding。初始化之后,只需要传入words就可以得到对应的embedding。
别名::class:`fastNLP.modules.ElmoEmbedding` :class:`fastNLP.modules.encoder.embedding.ElmoEmbedding`

使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。
我们提供的ELMo预训练模型来自 https://github.com/HIT-SCIR/ELMoForManyLangs

Example::

@@ -290,12 +299,13 @@ class ElmoEmbedding(DynamicEmbedding):
>>>

:param vocab: 词表
:param model_dir_or_name: 模型存放的目录或者模型的名称(将自动查看缓存中是否存在该模型,没有的话将自动下载)
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称,
目前支持的ELMo包括{`en` : 英文版本的ELMo, `cn` : 中文版本的ELMo,}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载
:param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果
按照这个顺序concat起来。
:param requires_grad: bool, 该层是否需要gradient.
按照这个顺序concat起来。默认为'2'。
:param requires_grad: bool, 该层是否需要gradient. 默认为False
:param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding,
并删除character encoder,之后将直接使用cache的embedding。
并删除character encoder,之后将直接使用cache的embedding。默认为False。
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
@@ -370,12 +380,15 @@ class ElmoEmbedding(DynamicEmbedding):
param.requires_grad = value


class BertEmbedding(DynamicEmbedding):
class BertEmbedding(ContextualEmbedding):
"""
使用bert对words进行encode的Embedding。
别名::class:`fastNLP.modules.BertEmbedding` :class:`fastNLP.modules.encoder.embedding.BertEmbedding`

使用BERT对words进行encode的Embedding。

Example::

>>>


:param fastNLP.Vocabulary vocab: 词表
@@ -395,7 +408,7 @@ class BertEmbedding(DynamicEmbedding):
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': '',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
@@ -478,23 +491,27 @@ def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1):

class CNNCharEmbedding(TokenEmbedding):
"""
别名::class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`

使用CNN生成character embedding。CNN的结果为, CNN(x) -> activation(x) -> pool -> fc. 不同的kernel大小的fitler结果是
concat起来的。

Example::

>>>


:param vocab:
:param embed_size: 该word embedding的大小
:param char_emb_size: character的embed的大小。character是从vocab中生成的。
:param filter_nums: filter的数量. 长度需要和kernels一致。
:param kernels: kernel的大小.
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数
:param min_char_freq: character的最少出现次数。
:param vocab: 词表
:param embed_size: 该word embedding的大小,默认值为50.
:param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50.
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20].
:param kernels: kernel的大小. 默认值为[5, 3, 1].
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
:param min_char_freq: character的最少出现次数。默认值为2.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50,
filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method='max',
filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
activation='relu', min_char_freq: int=2):
super(CNNCharEmbedding, self).__init__(vocab)

@@ -600,6 +617,8 @@ class CNNCharEmbedding(TokenEmbedding):

class LSTMCharEmbedding(TokenEmbedding):
"""
别名::class:`fastNLP.modules.LSTMCharEmbedding` :class:`fastNLP.modules.encoder.embedding.LSTMCharEmbedding`

使用LSTM的方式对character进行encode.

Example::
@@ -607,16 +626,16 @@ class LSTMCharEmbedding(TokenEmbedding):
>>>

:param vocab: 词表
:param embed_size: embedding的大小
:param char_emb_size: character的embedding的大小。
:param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二
:param embed_size: embedding的大小。默认值为50.
:param char_emb_size: character的embedding的大小。默认值为50.
:param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50.
:param pool_method: 支持'max', 'avg'
:param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
:param min_char_freq: character的最小出现次数。
:param bidirectional: 是否使用双向的LSTM进行encode。
:param min_char_freq: character的最小出现次数。默认值为2.
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, hidden_size=50,
pool_method='max', activation='relu', min_char_freq: int=2, bidirectional=True):
pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
super(LSTMCharEmbedding, self).__init__(vocab)

assert hidden_size % 2 == 0, "Only even kernel is allowed."
@@ -669,8 +688,8 @@ class LSTMCharEmbedding(TokenEmbedding):
"""
输入words的index后,生成对应的words的表示。

:param words: batch_size x max_len
:return: batch_size x max_len x embed_size
:param words: [batch_size, max_len]
:return: [batch_size, max_len, embed_size]
"""
batch_size, max_len = words.size()
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
@@ -681,16 +700,18 @@ class LSTMCharEmbedding(TokenEmbedding):
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size

reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size*max_len)
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1) # B x M x M x H
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
# B x M x M x H

lstm_chars = self.activation(lstm_chars)
if self.pool_method == 'max':
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
else:
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
chars = torch.sum(lstm_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()

chars = self.fc(chars)

@@ -707,7 +728,7 @@ class LSTMCharEmbedding(TokenEmbedding):
if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
params.append(param)
requires_grads = set(params)
if len(requires_grads)==1:
if len(requires_grads) == 1:
return requires_grads.pop()
else:
return None
@@ -715,34 +736,41 @@ class LSTMCharEmbedding(TokenEmbedding):
@requires_grad.setter
def requires_grad(self, value):
for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
pass
param.requires_grad = value


class StackEmbedding(TokenEmbedding):
"""
别名::class:`fastNLP.modules.StackEmbedding` :class:`fastNLP.modules.encoder.embedding.StackEmbedding`

支持将多个embedding集合成一个embedding。

Example::

>>>


:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致

"""
def __init__(self, embeds:List[TokenEmbedding]):
def __init__(self, embeds: List[TokenEmbedding]):
vocabs = []
for embed in embeds:
vocabs.append(embed.get_word_vocab())
_vocab = vocabs[0]
for vocab in vocabs[1:]:
assert vocab==_vocab, "All embeddings should use the same word vocabulary."
assert vocab == _vocab, "All embeddings should use the same word vocabulary."

super().__init__(_vocab)
super(StackEmbedding, self).__init__(_vocab)
assert isinstance(embeds, list)
for embed in embeds:
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
self.embeds = nn.ModuleList(embeds)
self._embed_size = sum([embed.embed_size for embed in self.embeds])

def append(self, embed:TokenEmbedding):
def append(self, embed: TokenEmbedding):
"""
添加一个embedding到结尾。
:param embed:


Loading…
Cancel
Save