
update embedding.py

tags/v0.4.10
xuyige 6 years ago
commit 0f7c732f21
1 changed file with 106 additions and 82 deletions:
fastNLP/modules/encoder/embedding.py (+106, -82)

@@ -33,7 +33,7 @@ class Embedding(nn.Module):
            a TokenEmbedding object may also be passed in
        :param float dropout: dropout applied to the output of the Embedding.
        """
-        super().__init__()
+        super(Embedding, self).__init__()

        self.embed = get_embeddings(init_embed)
@@ -52,11 +52,11 @@ class Embedding(nn.Module):
        return self.dropout(x)

    @property
-    def embed_size(self)->int:
+    def embed_size(self) -> int:
        return self._embed_size

    @property
-    def embedding_dim(self)->int:
+    def embedding_dim(self) -> int:
        return self._embed_size

    @property
@@ -84,10 +84,11 @@ class Embedding(nn.Module):
        else:
            return self.embed.weight.size()

+
class TokenEmbedding(nn.Module):
    def __init__(self, vocab):
-        super().__init__()
-        assert vocab.padding_idx!=None, "Your vocabulary must have padding."
+        super(TokenEmbedding, self).__init__()
+        assert vocab.padding_idx is not None, "Your vocabulary must have padding."
        self._word_vocab = vocab
        self._word_pad_index = vocab.padding_idx


@@ -98,7 +99,7 @@ class TokenEmbedding(nn.Module):
        :return:
        """
        requires_grads = set([param.requires_grad for param in self.parameters()])
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None
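
The requires_grad property above returns True or False only when every parameter agrees, and None when the flags are mixed. A minimal standalone sketch of that pattern (the helper name is hypothetical, not part of fastNLP):

import torch.nn as nn

def aggregate_requires_grad(module: nn.Module):
    # True/False when every parameter agrees, None when the flags are mixed
    flags = {param.requires_grad for param in module.parameters()}
    return flags.pop() if len(flags) == 1 else None

layer = nn.Linear(4, 4)
layer.bias.requires_grad = False
print(aggregate_requires_grad(layer))  # None, because weight and bias disagree
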
@@ -113,7 +114,7 @@ class TokenEmbedding(nn.Module):
        pass

    @property
-    def embed_size(self)->int:
+    def embed_size(self) -> int:
        return self._embed_size

    def get_word_vocab(self):
@@ -128,8 +129,9 @@ class TokenEmbedding(nn.Module):
    def size(self):
        return torch.Size(self.embed._word_vocab, self._embed_size)

+
class StaticEmbedding(TokenEmbedding):
-    def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en', requires_grad:bool=False):
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=False):
        """
        Given the name of an embedding, extract the corresponding data from it according to vocab. The resulting Embedding can then be used just like a normal embedding.


@@ -140,19 +142,20 @@ class StaticEmbedding(TokenEmbedding):
        :param model_dir_or_name: location of the resource; an abbreviated embedding name may be passed in. See xxx for the resources each embedding corresponds to.
        :param requires_grad: whether a gradient is required
        """
-        super().__init__(vocab)
+        super(StaticEmbedding, self).__init__(vocab)

        # First define which static embeddings can be downloaded. This will probably require setting up our own server,
        PRETRAIN_URL = _get_base_url('static')
        PRETRAIN_STATIC_FILES = {
            'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
+            'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
            'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
            'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
            'cn': "tencent_cn-dab24577.tar.gz"
        }

        # get the cache_path
-        if model_dir_or_name in PRETRAIN_STATIC_FILES:
+        if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
            model_name = PRETRAIN_STATIC_FILES[model_dir_or_name]
            model_url = PRETRAIN_URL + model_name
            model_path = cached_path(model_url)
@@ -167,8 +170,8 @@ class StaticEmbedding(TokenEmbedding):
        embedding = torch.tensor(embedding)
        self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                      padding_idx=vocab.padding_idx,
-                                          max_norm=None, norm_type=2, scale_grad_by_freq=False,
-                                          sparse=False, _weight=embedding)
+                                      max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                                      sparse=False, _weight=embedding)
        self._embed_size = self.embedding.weight.size(1)
        self.requires_grad = requires_grad
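
For reference, wrapping a pretrained weight matrix in nn.Embedding via the _weight argument, as the hunk above does, can be sketched on its own; the random matrix here is a stand-in for vectors loaded from disk, and the sizes are assumptions:

import torch
import torch.nn as nn

num_words, dim, pad_idx = 100, 50, 0
pretrained = torch.randn(num_words, dim)          # stand-in for vectors read from a file
embed = nn.Embedding(num_embeddings=num_words, embedding_dim=dim,
                     padding_idx=pad_idx, _weight=pretrained)
words = torch.tensor([[1, 2, 0]])                 # batch_size x max_len of word indices
print(embed(words).shape)                         # torch.Size([1, 3, 50])
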


@@ -181,9 +184,10 @@ class StaticEmbedding(TokenEmbedding):
        """
        return self.embedding(words)

-class DynmicEmbedding(TokenEmbedding):
-    def __init__(self, vocab:Vocabulary):
-        super().__init__(vocab)
+
+class DynamicEmbedding(TokenEmbedding):
+    def __init__(self, vocab: Vocabulary):
+        super(DynamicEmbedding, self).__init__(vocab)

    def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights:bool=True):
        """
@@ -256,7 +260,7 @@ class DynmicEmbedding(TokenEmbedding):
            _embeds.append(embed)
        max_sent_len = max(map(len, _embeds))
        embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
-                                     device=words.device)
+                                 device=words.device)
        for i, embed in enumerate(_embeds):
            embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
        return embeds
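
The padding step in this hunk, collecting variable-length sentence embeddings into one zero-padded tensor, can be illustrated in isolation; the sentence embeddings below are made up:

import torch

embed_size = 4
_embeds = [torch.randn(3, embed_size), torch.randn(5, embed_size)]   # two sentences of length 3 and 5
max_sent_len = max(len(e) for e in _embeds)
embeds = torch.zeros(len(_embeds), max_sent_len, embed_size)
for i, embed in enumerate(_embeds):
    embeds[i, :len(embed)] = embed            # shorter sentences stay zero-padded at the end
print(embeds.shape)                           # torch.Size([2, 5, 4])
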
@@ -276,7 +280,7 @@ class DynmicEmbedding(TokenEmbedding):
        del self.sent_embeds


-class ElmoEmbedding(DynmicEmbedding):
+class ElmoEmbedding(DynamicEmbedding):
    """
    An Embedding based on ELMo. After initialization, simply pass in words to get the corresponding embeddings.


@@ -293,13 +297,13 @@ class ElmoEmbedding(DynmicEmbedding):
    :param cache_word_reprs: optionally cache the word representations; if True, an embedding is generated for every word at
        initialization time, the character encoder is deleted, and the cached embeddings are used directly afterwards.
    """
-    def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en',
-                 layers:str='2', requires_grad:bool=False, cache_word_reprs:bool=False):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
+                 layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
+        super(ElmoEmbedding, self).__init__(vocab)
        layers = list(map(int, layers.split(',')))
-        assert len(layers)>0, "Must choose one output"
+        assert len(layers) > 0, "Must choose one output"
        for layer in layers:
-            assert 0<=layer<=2, "Layer index should be in range [0, 2]."
+            assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
        self.layers = layers

        # check model_dir_or_name and download it if necessary
@@ -308,7 +312,7 @@ class ElmoEmbedding(DynmicEmbedding):
        PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz',
                                     'cn': 'elmo_cn-5e9b34e2.tar.gz'}

-        if model_dir_or_name in PRETRAINED_ELMO_MODEL_DIR:
+        if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
            model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
            model_url = PRETRAIN_URL + model_name
            model_dir = cached_path(model_url)
@@ -319,9 +323,9 @@ class ElmoEmbedding(DynmicEmbedding):
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")
        self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
        self.requires_grad = requires_grad
-        self._embed_size = len(self.layers)*self.model.config['encoder']['projection_dim']*2
+        self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2

-    def forward(self, words:torch.LongTensor):
+    def forward(self, words: torch.LongTensor):
        """
        Compute the ELMo embedding of words. As described in the ELMo paper, ELMo actually produces 2L+1 layers of results, but to make the results easier to split apart, the token
        embedding is repeated once, so that layer=0 is actually [token_embedding; token_embedding], while layer=1 is [forward_hiddens;
@@ -334,7 +338,7 @@ class ElmoEmbedding(DynmicEmbedding):
        if outputs is not None:
            return outputs
        outputs = self.model(words)
-        if len(self.layers)==1:
+        if len(self.layers) == 1:
            outputs = outputs[self.layers[0]]
        else:
            outputs = torch.cat([*outputs[self.layers]], dim=-1)
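
The layer selection above concatenates the chosen ELMo layers along the feature dimension, which is why _embed_size is len(layers) * 2 * projection_dim. A minimal sketch with toy tensors (projection_dim and the output layout are assumptions, not fastNLP's actual model output):

import torch

projection_dim, batch_size, seq_len = 8, 2, 5
# stand-in for the model output: 3 layers, each holding [forward; backward] -> 2 * projection_dim features
outputs = torch.randn(3, batch_size, seq_len, 2 * projection_dim)
layers = [1, 2]
selected = torch.cat([*outputs[layers]], dim=-1)
print(selected.shape)   # torch.Size([2, 5, 32]) == len(layers) * 2 * projection_dim
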
@@ -353,7 +357,7 @@ class ElmoEmbedding(DynmicEmbedding):
        """
        requires_grads = set([param.requires_grad for name, param in self.named_parameters()
                              if 'words_to_chars_embedding' not in name])
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None
@@ -366,7 +370,7 @@ class ElmoEmbedding(DynmicEmbedding):
            param.requires_grad = value


-class BertEmbedding(DynmicEmbedding):
+class BertEmbedding(DynamicEmbedding):
    """
    An Embedding that uses BERT to encode words.


@@ -374,25 +378,35 @@ class BertEmbedding(DynmicEmbedding):


-    :param vocab: Vocabulary
-    :param model_dir_or_name: the directory where the model is located, or the model's name.
-    :param layers: which representations go into the final result; layer numbers separated by ',', negative numbers index layers from the end.
-    :param pool_method: in BERT every word is represented by several word pieces; this controls how a word's representation is
-        computed from its word pieces. Supports 'last', 'first', 'avg', 'max'.
-    :param include_cls_sep: bool, BERT needs [CLS] and [SEP] added when computing the sentence representation; whether to keep these two tokens in the result. This
+    :param fastNLP.Vocabulary vocab: the vocabulary
+    :param str model_dir_or_name: the directory where the model is located, or the model's name. Defaults to ``en-base-uncased``
+    :param str layers: which representations go into the final result; layer numbers separated by ',', negative numbers index layers from the end.
+    :param str pool_method: in BERT every word is represented by several word pieces; this controls how a word's representation is
+        computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
+    :param bool include_cls_sep: bool, BERT needs [CLS] and [SEP] added when computing the sentence representation; whether to keep these two tokens in the result. This
        makes the word embedding output two tokens longer than the input and may cause problems when used with :class::StackEmbedding.
-    :param requires_grad: whether a gradient is required.
+    :param bool requires_grad: whether a gradient is required.
    """
-    def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', pool_method:str='first',
-                 include_cls_sep:bool=False, requires_grad:bool=False):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
+                 pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
+        super(BertEmbedding, self).__init__(vocab)
        # check model_dir_or_name and download it if necessary
        PRETRAIN_URL = _get_base_url('bert')
        PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
-                                     'en-base': 'bert-base-cased-f89bfe08.zip',
-                                     'cn-base': 'bert-base-chinese-29d0a84a.zip'}
+                                     'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
+                                     'en-base-cased': 'bert-base-cased-f89bfe08.zip',
+                                     'en-large-uncased': '',
+                                     'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
+
+                                     'cn': 'bert-base-chinese-29d0a84a.zip',
+                                     'cn-base': 'bert-base-chinese-29d0a84a.zip',
+
+                                     'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                     'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
+                                     'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                     }

-        if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
+        if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
            model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
            model_url = PRETRAIN_URL + model_name
            model_dir = cached_path(model_url)
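
The pool_method described in the docstring above, how a word's vector is derived from its word pieces, can be sketched independently of BERT; the piece vectors and the hidden size are made up:

import torch

pieces = torch.randn(4, 768)          # 4 word-piece vectors for one word, hidden size 768 (assumption)

def pool(pieces, method):
    if method == 'first':
        return pieces[0]
    if method == 'last':
        return pieces[-1]
    if method == 'avg':
        return pieces.mean(dim=0)
    if method == 'max':
        return pieces.max(dim=0)[0]
    raise ValueError(method)

print(pool(pieces, 'avg').shape)      # torch.Size([768]), one vector per word
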
@@ -435,7 +449,7 @@ class BertEmbedding(DynmicEmbedding):
        """
        requires_grads = set([param.requires_grad for name, param in self.named_parameters()
                              if 'word_pieces_lengths' not in name])
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None
@@ -443,10 +457,11 @@ class BertEmbedding(DynmicEmbedding):
    @requires_grad.setter
    def requires_grad(self, value):
        for name, param in self.named_parameters():
-            if 'word_pieces_lengths' in name: # this one must not be included in requires_grad
+            if 'word_pieces_lengths' in name:  # this one must not be included in requires_grad
                pass
            param.requires_grad = value

+
def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1):
    """
    Given a word vocabulary, build a character vocabulary.
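
What _construct_char_vocab_from_vocab is described as doing, counting characters over the word vocabulary and keeping those above min_freq, can be roughly sketched in plain Python (fastNLP's Vocabulary API is not assumed here; the special tokens are illustrative):

from collections import Counter

def build_char_vocab(words, min_freq=1, specials=('<pad>', '<unk>')):
    # count every character across the word list, then keep the frequent ones
    counts = Counter(ch for word in words for ch in word)
    chars = [ch for ch, freq in counts.items() if freq >= min_freq]
    return {ch: idx for idx, ch in enumerate(list(specials) + chars)}

print(build_char_vocab(['cat', 'car', 'dog'], min_freq=2))   # only 'c' and 'a' pass min_freq=2
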
@@ -475,30 +490,34 @@ class CNNCharEmbedding(TokenEmbedding):
    :param filter_nums: number of filters; the length must match kernels.
    :param kernels: kernel sizes.
    :param pool_method: pooling method used when combining the character representations into a single one; supports 'avg', 'max'
-    :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh'
+    :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh' or a custom function
    :param min_char_freq: minimum number of occurrences of a character.
    """
-    def __init__(self, vocab:Vocabulary, embed_size:int=50, char_emb_size:int=50, filter_nums:List[int]=(40, 30, 20),
-                 kernel_sizes:List[int]=(5, 3, 1), pool_method='max', activation='relu', min_char_freq:int=2):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50,
+                 filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method='max',
+                 activation='relu', min_char_freq: int=2):
+        super(CNNCharEmbedding, self).__init__(vocab)

        for kernel in kernel_sizes:
-            assert kernel%2==1, "Only odd kernel is allowed."
+            assert kernel % 2 == 1, "Only odd kernel is allowed."

        assert pool_method in ('max', 'avg')
        self.pool_method = pool_method
        # activation function
-        if activation == 'relu':
-            self.activation = F.relu
-        elif activation == 'sigmoid':
-            self.activation = F.sigmoid
-        elif activation == 'tanh':
-            self.activation = F.tanh
-        elif activation == None:
-            self.activation = lambda x:x
+        if isinstance(activation, str):
+            if activation.lower() == 'relu':
+                self.activation = F.relu
+            elif activation.lower() == 'sigmoid':
+                self.activation = F.sigmoid
+            elif activation.lower() == 'tanh':
+                self.activation = F.tanh
+        elif activation is None:
+            self.activation = lambda x: x
+        elif callable(activation):
+            self.activation = activation
        else:
            raise Exception(
-                "Undefined activation function: choose from: relu, tanh, sigmoid")
+                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")

        print("Start constructing character vocabulary.")
        # build the char vocabulary
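
The new activation handling above (a string name, None, or a custom callable) can be written as a small standalone helper; a sketch under the same assumptions as the diff, not fastNLP's actual code:

import torch
import torch.nn.functional as F

def resolve_activation(activation):
    names = {'relu': F.relu, 'sigmoid': torch.sigmoid, 'tanh': torch.tanh}
    if isinstance(activation, str) and activation.lower() in names:
        return names[activation.lower()]
    if activation is None:
        return lambda x: x          # identity when no activation is requested
    if callable(activation):
        return activation           # user-supplied function is used as-is
    raise ValueError("choose from relu, tanh, sigmoid, or a callable")

act = resolve_activation('ReLU')
print(act(torch.tensor([-1.0, 2.0])))   # tensor([0., 2.])
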
@@ -506,20 +525,21 @@ class CNNCharEmbedding(TokenEmbedding):
        self.char_pad_index = self.char_vocab.padding_idx
        print(f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index the vocab
-        self.max_word_len = max(map(lambda x:len(x[0]), vocab))
+        self.max_word_len = max(map(lambda x: len(x[0]), vocab))
        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
                                                                fill_value=self.char_pad_index, dtype=torch.long),
                                                     requires_grad=False)
        self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
        for word, index in vocab:
            # if index!=vocab.padding_idx:  # if it is pad, it already equals pad_value. Changed to not special-case pad, so every <pad> shares the same embed
-            self.words_to_chars_embedding[index, :len(word)] = torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+            self.words_to_chars_embedding[index, :len(word)] = \
+                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)
        self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)

-        self.convs = nn.ModuleList([
-            nn.Conv1d(char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
-            for i in range(len(kernel_sizes))])
+        self.convs = nn.ModuleList([nn.Conv1d(
+            char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
+            for i in range(len(kernel_sizes))])
        self._embed_size = embed_size
        self.fc = nn.Linear(sum(filter_nums), embed_size)


@@ -527,8 +547,8 @@ class CNNCharEmbedding(TokenEmbedding):
        """
        Given the indices of words, produce the corresponding word representations.

-        :param words: batch_size x max_len
-        :return: batch_size x max_len x embed_size
+        :param words: [batch_size, max_len]
+        :return: [batch_size, max_len, embed_size]
        """
        batch_size, max_len = words.size()
        chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
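
To make the shape comments above concrete, here is a toy pass through the same kind of char-CNN pipeline: look up character ids, embed them, convolve over each word's characters, max-pool, and project. All sizes are made up and only one kernel and max-pooling are used:

import torch
import torch.nn as nn

batch_size, max_len, max_word_len = 2, 5, 7
char_vocab_size, char_emb_size, n_filters, kernel = 30, 16, 20, 3

chars = torch.randint(0, char_vocab_size, (batch_size, max_len, max_word_len))
char_embed = nn.Embedding(char_vocab_size, char_emb_size)
conv = nn.Conv1d(char_emb_size, n_filters, kernel_size=kernel, padding=kernel // 2)
fc = nn.Linear(n_filters, 50)

x = char_embed(chars)                                          # B x L x W x C
x = x.view(-1, max_word_len, char_emb_size).transpose(1, 2)    # (B*L) x C x W for Conv1d
x = conv(x)                                                    # (B*L) x filters x W
x = x.max(dim=-1)[0]                                           # pool over the characters of each word
out = fc(x).view(batch_size, max_len, -1)                      # B x L x embed_size
print(out.shape)                                               # torch.Size([2, 5, 50])
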
@@ -565,7 +585,7 @@ class CNNCharEmbedding(TokenEmbedding):
            if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
                params.append(param.requires_grad)
        requires_grads = set(params)
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None
@@ -573,7 +593,7 @@ class CNNCharEmbedding(TokenEmbedding):
    @requires_grad.setter
    def requires_grad(self, value):
        for name, param in self.named_parameters():
-            if 'words_to_chars_embedding' in name or 'word_lengths' in name: # this one must not be included in requires_grad
+            if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # this one must not be included in requires_grad
                pass
            param.requires_grad = value


@@ -591,13 +611,13 @@ class LSTMCharEmbedding(TokenEmbedding):
    :param char_emb_size: size of the character embedding.
    :param hidden_size: size of the LSTM hidden state; if bidirectional, the hidden size is divided by two
    :param pool_method: supports 'max', 'avg'
-    :param activation: activation function; supports 'relu', 'sigmoid', 'tanh'.
+    :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom function.
    :param min_char_freq: minimum number of occurrences of a character.
    :param bidirectional: whether to encode with a bidirectional LSTM.
    """
-    def __init__(self, vocab:Vocabulary, embed_size:int=50, char_emb_size:int=50, hidden_size=50,
-                 pool_method='max', activation='relu', min_char_freq:int=2, bidirectional=True):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, hidden_size=50,
+                 pool_method='max', activation='relu', min_char_freq: int=2, bidirectional=True):
+        super(LSTMCharEmbedding, self).__init__(vocab)

        assert hidden_size % 2 == 0, "Only even kernel is allowed."


@@ -605,17 +625,20 @@ class LSTMCharEmbedding(TokenEmbedding):
        self.pool_method = pool_method

        # activation function
-        if activation == 'relu':
-            self.activation = F.relu
-        elif activation == 'sigmoid':
-            self.activation = F.sigmoid
-        elif activation == 'tanh':
-            self.activation = F.tanh
-        elif activation == None:
-            self.activation = lambda x:x
+        if isinstance(activation, str):
+            if activation.lower() == 'relu':
+                self.activation = F.relu
+            elif activation.lower() == 'sigmoid':
+                self.activation = F.sigmoid
+            elif activation.lower() == 'tanh':
+                self.activation = F.tanh
+        elif activation is None:
+            self.activation = lambda x: x
+        elif callable(activation):
+            self.activation = activation
        else:
            raise Exception(
-                "Undefined activation function: choose from: relu, tanh, sigmoid")
+                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")

        print("Start constructing character vocabulary.")
        # build the char vocabulary
@@ -623,14 +646,15 @@ class LSTMCharEmbedding(TokenEmbedding):
        self.char_pad_index = self.char_vocab.padding_idx
        print(f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index the vocab
-        self.max_word_len = max(map(lambda x:len(x[0]), vocab))
+        self.max_word_len = max(map(lambda x: len(x[0]), vocab))
        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
                                                                fill_value=self.char_pad_index, dtype=torch.long),
                                                     requires_grad=False)
        self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
        for word, index in vocab:
            # if index!=vocab.padding_idx:  # if it is pad, it already equals pad_value. Changed to not special-case pad
-            self.words_to_chars_embedding[index, :len(word)] = torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+            self.words_to_chars_embedding[index, :len(word)] = \
+                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)
        self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)


@@ -650,7 +674,7 @@ class LSTMCharEmbedding(TokenEmbedding):
        """
        batch_size, max_len = words.size()
        chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
-        word_lengths = self.word_lengths[words] # batch_size x max_len
+        word_lengths = self.word_lengths[words]  # batch_size x max_len
        max_word_len = word_lengths.max()
        chars = chars[:, :, :max_word_len]
        # positions to be masked are 1
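
The mask mentioned in the trailing comment (1 at padded character positions) can be sketched as follows; char_pad_index is assumed to be 0 here, and the character ids are made up:

import torch

char_pad_index = 0
chars = torch.tensor([[2, 5, 0, 0],
                      [3, 0, 0, 0]])          # per-word character ids, 0 = <pad>
chars_masks = chars.eq(char_pad_index)        # True (1) exactly where the position is padding
print(chars_masks)
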

