From 0f7c732f21e37b944e6cf7925a40d29557ef250b Mon Sep 17 00:00:00 2001
From: xuyige
Date: Wed, 12 Jun 2019 19:08:36 +0800
Subject: [PATCH] update embedding.py

---
 fastNLP/modules/encoder/embedding.py | 188 +++++++++++++++------------
 1 file changed, 106 insertions(+), 82 deletions(-)

diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
index 45ba7885..2f3007df 100644
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
@@ -33,7 +33,7 @@ class Embedding(nn.Module):
         a TokenEmbedding object may also be passed in
         :param float dropout: dropout applied to the output of the Embedding.
         """
-        super().__init__()
+        super(Embedding, self).__init__()
 
         self.embed = get_embeddings(init_embed)
 
@@ -52,11 +52,11 @@ class Embedding(nn.Module):
         return self.dropout(x)
 
     @property
-    def embed_size(self)->int:
+    def embed_size(self) -> int:
         return self._embed_size
 
     @property
-    def embedding_dim(self)->int:
+    def embedding_dim(self) -> int:
         return self._embed_size
 
     @property
@@ -84,10 +84,11 @@ class Embedding(nn.Module):
         else:
             return self.embed.weight.size()
 
+
 class TokenEmbedding(nn.Module):
     def __init__(self, vocab):
-        super().__init__()
-        assert vocab.padding_idx!=None, "You vocabulary must have padding."
+        super(TokenEmbedding, self).__init__()
+        assert vocab.padding_idx is not None, "Your vocabulary must have padding."
         self._word_vocab = vocab
         self._word_pad_index = vocab.padding_idx
 
@@ -98,7 +99,7 @@ class TokenEmbedding(nn.Module):
         :return:
         """
         requires_grads = set([param.requires_grad for param in self.parameters()])
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
             return requires_grads.pop()
         else:
             return None
@@ -113,7 +114,7 @@ class TokenEmbedding(nn.Module):
         pass
 
     @property
-    def embed_size(self)->int:
+    def embed_size(self) -> int:
         return self._embed_size
 
     def get_word_vocab(self):
@@ -128,8 +129,9 @@ class TokenEmbedding(nn.Module):
     def size(self):
         return torch.Size(self.embed._word_vocab, self._embed_size)
 
+
 class StaticEmbedding(TokenEmbedding):
-    def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en', requires_grad:bool=False):
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=False):
         """
         Given the name of a pretrained embedding, extract the vectors for the words in vocab from it.
         The result can then be used like a normal embedding.
 
         Example::
 
         :param vocab: Vocabulary
         :param model_dir_or_name: location of the resource; a shorthand embedding name may be passed in,
             see xxx for the resource each name refers to
         :param requires_grad: whether a gradient is required
         """
-        super().__init__(vocab)
+        super(StaticEmbedding, self).__init__(vocab)
 
         # First declare which static embeddings are downloadable. We will probably need to set up our own server for this.
         PRETRAIN_URL = _get_base_url('static')
         PRETRAIN_STATIC_FILES = {
             'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
+            'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
             'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
             'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
             'cn': "tencent_cn-dab24577.tar.gz"
         }
 
         # resolve the cache path
-        if model_dir_or_name in PRETRAIN_STATIC_FILES:
-            model_name = PRETRAIN_STATIC_FILES[model_dir_or_name]
+        if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
+            model_name = PRETRAIN_STATIC_FILES[model_dir_or_name.lower()]
             model_url = PRETRAIN_URL + model_name
             model_path = cached_path(model_url)
@@ -167,8 +170,8 @@ class StaticEmbedding(TokenEmbedding):
         embedding = torch.tensor(embedding)
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
-                                          max_norm=None, norm_type=2, scale_grad_by_freq=False,
-                                          sparse=False, _weight=embedding)
+                                      max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                                      sparse=False, _weight=embedding)
         self._embed_size = self.embedding.weight.size(1)
         self.requires_grad = requires_grad
 
@@ -181,9 +184,10 @@ class StaticEmbedding(TokenEmbedding):
         """
         return self.embedding(words)
 
-class DynmicEmbedding(TokenEmbedding):
-    def __init__(self, vocab:Vocabulary):
-        super().__init__(vocab)
+
+class DynamicEmbedding(TokenEmbedding):
+    def __init__(self, vocab: Vocabulary):
+        super(DynamicEmbedding, self).__init__(vocab)
 
     def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights:bool=True):
         """
@@ -256,7 +260,7 @@ class DynmicEmbedding(TokenEmbedding):
             _embeds.append(embed)
         max_sent_len = max(map(len, _embeds))
         embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
-                                     device=words.device)
+                                 device=words.device)
         for i, embed in enumerate(_embeds):
             embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
         return embeds
@@ -276,7 +280,7 @@ class DynmicEmbedding(TokenEmbedding):
         del self.sent_embeds
 
 
-class ElmoEmbedding(DynmicEmbedding):
+class ElmoEmbedding(DynamicEmbedding):
     """
     Embedding that uses ELMo. After initialization, simply pass in words to get the corresponding embeddings.
 
     Example::
 
     :param cache_word_reprs: whether to cache the word representations; if set to True, an embedding is generated
         for every word at initialization time, the character encoder is deleted, and the cached embeddings are
         used directly from then on.
     """
-    def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en',
-                 layers:str='2', requires_grad:bool=False, cache_word_reprs:bool=False):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
+                 layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
+        super(ElmoEmbedding, self).__init__(vocab)
         layers = list(map(int, layers.split(',')))
-        assert len(layers)>0, "Must choose one output"
+        assert len(layers) > 0, "Must choose at least one output"
         for layer in layers:
-            assert 0<=layer<=2, "Layer index should be in range [0, 2]."
+            assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
         self.layers = layers
 
         # check whether model_dir_or_name exists, downloading it if necessary
@@ -308,7 +312,7 @@ class ElmoEmbedding(DynmicEmbedding):
         PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz',
                                      'cn': 'elmo_cn-5e9b34e2.tar.gz'}
 
-        if model_dir_or_name in PRETRAINED_ELMO_MODEL_DIR:
-            model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
+        if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
+            model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name.lower()]
             model_url = PRETRAIN_URL + model_name
             model_dir = cached_path(model_url)
@@ -319,9 +323,9 @@ class ElmoEmbedding(DynmicEmbedding):
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
         self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
         self.requires_grad = requires_grad
-        self._embed_size = len(self.layers)*self.model.config['encoder']['projection_dim']*2
+        self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2
 
-    def forward(self, words:torch.LongTensor):
+    def forward(self, words: torch.LongTensor):
         """
         Compute the ELMo embedding of words. As described in the ELMo paper there are actually 2L+1 layers of
         results, but to make the result easier to split, the token embedding is repeated, so that the layer=0
         result is [token_embedding;token_embedding] while the layer=1 result is [forward_hiddens;
@@ -334,7 +338,7 @@ class ElmoEmbedding(DynmicEmbedding):
         if outputs is not None:
             return outputs
         outputs = self.model(words)
-        if len(self.layers)==1:
+        if len(self.layers) == 1:
             outputs = outputs[self.layers[0]]
         else:
             outputs = torch.cat([*outputs[self.layers]], dim=-1)
@@ -353,7 +357,7 @@ class ElmoEmbedding(DynmicEmbedding):
         """
         requires_grads = set([param.requires_grad for name, param in self.named_parameters()
                               if 'words_to_chars_embedding' not in name])
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
             return requires_grads.pop()
         else:
             return None
@@ -366,7 +370,7 @@ class ElmoEmbedding(DynmicEmbedding):
             param.requires_grad = value
 
 
-class BertEmbedding(DynmicEmbedding):
+class BertEmbedding(DynamicEmbedding):
     """
     An Embedding that uses BERT to encode words.
 
     Example::
 
@@ -374,25 +378,35 @@ class BertEmbedding(DynmicEmbedding):
-    :param vocab: Vocabulary
-    :param model_dir_or_name: directory containing the model, or the name of the model.
-    :param layers: the representations to use in the final result; layer numbers separated by ',', negative numbers index from the last layer
-    :param pool_method: since in BERT every word is represented by several word pieces, this decides how the
-        representation of a word is computed from its word pieces. Supports 'last', 'first', 'avg', 'max'.
-    :param include_cls_sep: bool, when BERT computes the sentence representation it adds [CLS] and [SEP] around it; whether to keep these two in the result. Keeping them
+    :param fastNLP.Vocabulary vocab: the vocabulary
+    :param str model_dir_or_name: directory containing the model, or the name of the model. Defaults to ``en-base-uncased``
+    :param str layers: the representations to use in the final result; layer numbers separated by ',', negative numbers index from the last layer
+    :param str pool_method: since in BERT every word is represented by several word pieces, this decides how the
+        representation of a word is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
+    :param bool include_cls_sep: when BERT computes the sentence representation it adds [CLS] and [SEP] around it; whether to keep these two in the result. Keeping them
         makes the word embedding output two tokens longer than the input; problems may arise when it is used with :class:`StackEmbedding`.
-    :param requires_grad: whether a gradient is required.
+    :param bool requires_grad: whether a gradient is required.
     """
-    def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', pool_method:str='first',
-                 include_cls_sep:bool=False, requires_grad:bool=False):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
+                 pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
+        super(BertEmbedding, self).__init__(vocab)
         # check whether model_dir_or_name exists, downloading it if necessary
         PRETRAIN_URL = _get_base_url('bert')
         PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
-                                     'en-base': 'bert-base-cased-f89bfe08.zip',
-                                     'cn-base': 'bert-base-chinese-29d0a84a.zip'}
+                                     'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
+                                     'en-base-cased': 'bert-base-cased-f89bfe08.zip',
+                                     'en-large-uncased': '',
+                                     'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
+
+                                     'cn': 'bert-base-chinese-29d0a84a.zip',
+                                     'cn-base': 'bert-base-chinese-29d0a84a.zip',
+
+                                     'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                     'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
+                                     'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
+                                     }
 
-        if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
-            model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
+        if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
+            model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name.lower()]
             model_url = PRETRAIN_URL + model_name
             model_dir = cached_path(model_url)
@@ -435,7 +449,7 @@ class BertEmbedding(DynmicEmbedding):
         """
         requires_grads = set([param.requires_grad for name, param in self.named_parameters()
                               if 'word_pieces_lengths' not in name])
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
             return requires_grads.pop()
         else:
             return None
@@ -443,10 +457,11 @@ class BertEmbedding(DynmicEmbedding):
     @requires_grad.setter
     def requires_grad(self, value):
         for name, param in self.named_parameters():
-            if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中
-                pass
+            if 'word_pieces_lengths' in name:  # this one must never be made to require a gradient
+                continue
             param.requires_grad = value
 
+
 def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1):
     """
     Given a word vocabulary, build a character vocabulary.
@@ -475,30 +490,34 @@ class CNNCharEmbedding(TokenEmbedding):
     :param filter_nums: number of filters; its length must match kernels.
     :param kernels: kernel sizes.
     :param pool_method: pooling method used when merging the character representations into a single one; supports 'avg', 'max'
-    :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh'
+    :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh', or a custom callable
     :param min_char_freq: minimum number of occurrences of a character.
     """
-    def __init__(self, vocab:Vocabulary, embed_size:int=50, char_emb_size:int=50, filter_nums:List[int]=(40, 30, 20),
-                 kernel_sizes:List[int]=(5, 3, 1), pool_method='max', activation='relu', min_char_freq:int=2):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50,
+                 filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method='max',
+                 activation='relu', min_char_freq: int=2):
+        super(CNNCharEmbedding, self).__init__(vocab)
         for kernel in kernel_sizes:
-            assert kernel%2==1, "Only odd kernel is allowed."
+            assert kernel % 2 == 1, "Only odd kernel sizes are allowed."
 
         assert pool_method in ('max', 'avg')
         self.pool_method = pool_method
         # activation function
-        if activation == 'relu':
-            self.activation = F.relu
-        elif activation == 'sigmoid':
-            self.activation = F.sigmoid
-        elif activation == 'tanh':
-            self.activation = F.tanh
-        elif activation == None:
-            self.activation = lambda x:x
+        if isinstance(activation, str):
+            if activation.lower() == 'relu':
+                self.activation = F.relu
+            elif activation.lower() == 'sigmoid':
+                self.activation = F.sigmoid
+            elif activation.lower() == 'tanh':
+                self.activation = F.tanh
+            else:
+                raise Exception(
+                    "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
+        elif activation is None:
+            self.activation = lambda x: x
+        elif callable(activation):
+            self.activation = activation
         else:
             raise Exception(
-                "Undefined activation function: choose from: relu, tanh, sigmoid")
+                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
 
         print("Start constructing character vocabulary.")
         # build the character vocabulary
@@ -506,20 +525,21 @@ class CNNCharEmbedding(TokenEmbedding):
         self.char_pad_index = self.char_vocab.padding_idx
         print(f"In total, there are {len(self.char_vocab)} distinct characters.")
         # index the words in vocab
-        self.max_word_len = max(map(lambda x:len(x[0]), vocab))
+        self.max_word_len = max(map(lambda x: len(x[0]), vocab))
         self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
                                                                 fill_value=self.char_pad_index, dtype=torch.long),
                                                      requires_grad=False)
         self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
         for word, index in vocab:
             # if index!=vocab.padding_idx:  # pad would simply stay pad_value; changed so that pad is not special-cased and everything uses the same embedding
-            self.words_to_chars_embedding[index, :len(word)] = torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+            self.words_to_chars_embedding[index, :len(word)] = \
+                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
             self.word_lengths[index] = len(word)
         self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
 
-        self.convs = nn.ModuleList([
-            nn.Conv1d(char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
-            for i in range(len(kernel_sizes))])
+        self.convs = nn.ModuleList([nn.Conv1d(
+            char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
+            for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
 
@@ -527,8 +547,8 @@ class CNNCharEmbedding(TokenEmbedding):
         """
         Given the indices of words, produce the representations of those words.
 
-        :param words: batch_size x max_len
-        :return: batch_size x max_len x embed_size
+        :param words: [batch_size, max_len]
+        :return: [batch_size, max_len, embed_size]
         """
         batch_size, max_len = words.size()
         chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
@@ -565,7 +585,7 @@ class CNNCharEmbedding(TokenEmbedding):
             if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
                 params.append(param.requires_grad)
         requires_grads = set(params)
-        if len(requires_grads)==1:
+        if len(requires_grads) == 1:
             return requires_grads.pop()
         else:
             return None
@@ -573,7 +593,7 @@ class CNNCharEmbedding(TokenEmbedding):
    @requires_grad.setter
    def requires_grad(self, value):
        for name, param in self.named_parameters():
-            if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
-                pass
+            if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these must never be made to require a gradient
+                continue
             param.requires_grad = value
 
@@ -591,13 +611,13 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param char_emb_size: size of the character embedding.
     :param hidden_size: size of the hidden state of the LSTM; if bidirectional, the hidden size is halved
     :param pool_method: supports 'max', 'avg'
-    :param activation: activation function; supports 'relu', 'sigmoid', 'tanh'.
+    :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom callable.
     :param min_char_freq: minimum number of occurrences of a character.
     :param bidirectional: whether to encode with a bidirectional LSTM.
     """
-    def __init__(self, vocab:Vocabulary, embed_size:int=50, char_emb_size:int=50, hidden_size=50,
-                 pool_method='max', activation='relu', min_char_freq:int=2, bidirectional=True):
-        super().__init__(vocab)
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, hidden_size=50,
+                 pool_method='max', activation='relu', min_char_freq: int=2, bidirectional=True):
+        super(LSTMCharEmbedding, self).__init__(vocab)
 
         assert hidden_size % 2 == 0, "Only even hidden sizes are allowed."
 
@@ -605,17 +625,20 @@ class LSTMCharEmbedding(TokenEmbedding):
         self.pool_method = pool_method
         # activation function
-        if activation == 'relu':
-            self.activation = F.relu
-        elif activation == 'sigmoid':
-            self.activation = F.sigmoid
-        elif activation == 'tanh':
-            self.activation = F.tanh
-        elif activation == None:
-            self.activation = lambda x:x
+        if isinstance(activation, str):
+            if activation.lower() == 'relu':
+                self.activation = F.relu
+            elif activation.lower() == 'sigmoid':
+                self.activation = F.sigmoid
+            elif activation.lower() == 'tanh':
+                self.activation = F.tanh
+            else:
+                raise Exception(
+                    "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
+        elif activation is None:
+            self.activation = lambda x: x
+        elif callable(activation):
+            self.activation = activation
         else:
             raise Exception(
-                "Undefined activation function: choose from: relu, tanh, sigmoid")
+                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
 
         print("Start constructing character vocabulary.")
         # build the character vocabulary
@@ -623,14 +646,15 @@ class LSTMCharEmbedding(TokenEmbedding):
         self.char_pad_index = self.char_vocab.padding_idx
         print(f"In total, there are {len(self.char_vocab)} distinct characters.")
         # index the words in vocab
-        self.max_word_len = max(map(lambda x:len(x[0]), vocab))
+        self.max_word_len = max(map(lambda x: len(x[0]), vocab))
         self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
                                                                 fill_value=self.char_pad_index, dtype=torch.long),
                                                      requires_grad=False)
         self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
         for word, index in vocab:
             # if index!=vocab.padding_idx:  # pad would simply stay pad_value; changed so that pad is no longer special-cased
-            self.words_to_chars_embedding[index, :len(word)] = torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+            self.words_to_chars_embedding[index, :len(word)] = \
+                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
             self.word_lengths[index] = len(word)
         self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
 
@@ -650,7 +674,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         """
         batch_size, max_len = words.size()
         chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
-        word_lengths = self.word_lengths[words] # batch_size x max_len
+        word_lengths = self.word_lengths[words]  # batch_size x max_len
         max_word_len = word_lengths.max()
         chars = chars[:, :, :max_word_len]
         # positions that are masked are set to 1
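
A minimal usage sketch of the patched classes (not part of the patch itself), assuming this change is applied and the pretrained archives are reachable through _get_base_url; the sentence, the 'en-glove-6b-50' name, and the printed shapes are illustrative only::

    import torch
    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding

    # build a tiny vocabulary; TokenEmbedding asserts that it has a padding entry
    vocab = Vocabulary()
    vocab.add_word_lst("this is a test .".split())
    vocab.build_vocab()

    # model names are now matched case-insensitively, so 'EN-GLOVE-6B-50' resolves too
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=False)

    # activation now accepts a string or any callable, e.g. torch.nn.functional.relu
    char_embed = CNNCharEmbedding(vocab, embed_size=50, activation=torch.nn.functional.relu)

    words = torch.LongTensor([[vocab.to_index(w) for w in "this is a test .".split()]])
    print(embed(words).shape)       # torch.Size([1, 5, 50]) for the 50d GloVe archive
    print(char_embed(words).shape)  # torch.Size([1, 5, 50])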
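
Likewise, a sketch of the renamed BertEmbedding entry points, assuming the 'en-base-uncased' archive listed in PRETRAINED_BERT_MODEL_DIR has actually been uploaded to the model server; the embed_size comment assumes the usual 768-dim hidden size of a base model::

    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("hello world".split())
    vocab.build_vocab()

    # 'en-base-uncased' is the new default name; lookups are case-insensitive as well
    bert = BertEmbedding(vocab, model_dir_or_name='en-base-uncased',
                         layers='-1,-2', pool_method='avg', include_cls_sep=False)
    # one vector per requested layer, concatenated per word
    print(bert.embed_size)  # expected 2 * 768 = 1536 for a base model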