
1. Change the way bert and elmo cache sentence representations, so that indexing via a sentence_index column is no longer needed

tags/v0.4.10
yh_cc · 6 years ago
commit ed3098e1b8
1 changed file with 71 additions and 67 deletions

fastNLP/modules/encoder/embedding.py (+71, -67)
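A minimal sketch of the new scheme in plain PyTorch (the pad index and toy tensors below are illustrative assumptions; only the tuple-of-word-ids key construction is taken from the diff): instead of adding a sentence_index column to every DataSet and indexing a list of cached arrays by position, the cache becomes a dict keyed by each sentence's non-padding word ids, so any batch can be looked up directly from its words tensor.

import torch

pad_index = 0  # assumed padding id for this sketch
words = torch.tensor([[3, 7, 2, 0, 0],
                      [5, 1, 4, 9, 0]])

# Old scheme (removed): cache[i] held the embedding of dataset sentence i, so every
# DataSet needed an extra sentence_index input column to recover i at forward time.
# New scheme: key the cache by the sentence's own word ids; no extra column is needed.
cache = {}
seq_len = words.ne(pad_index).sum(dim=-1)      # true length of each sentence
for b in range(words.size(0)):
    n = int(seq_len[b])
    key = tuple(words[b, :n].tolist())         # e.g. (3, 7, 2)
    cache[key] = torch.randn(n, 4)             # stand-in for the real sentence embedding

# At forward time the same key is rebuilt from words and looked up directly.
print(cache[tuple(words[0, :int(seq_len[0])].tolist())].shape)  # torch.Size([3, 4])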

@@ -17,8 +17,7 @@ from typing import List

from ... import DataSet, Batch, SequentialSampler
from ...core.utils import _move_model_to_device, _get_model_device
import numpy as np
from ...core.utils import _build_args


class Embedding(nn.Module):
"""
@@ -44,15 +43,12 @@ class Embedding(nn.Module):
else:
self._embed_size = self.embed.embed_size
def forward(self, x, sentence_index=None):
def forward(self, x):
"""
:param torch.LongTensor x: [batch, seq_len]
:param torch.LongTensor sentence_index: [batch_size, ] used when some dynamic embeddings cache sentence representations.
:return: torch.Tensor : [batch, seq_len, embed_dim]
"""
# TODO change this to a more reasonable approach
inputs = _build_args(self.embed.forward, words=x, sentence_index=sentence_index)
x = self.embed(**inputs)
x = self.embed(x)
return self.dropout(x)

@property
@@ -81,9 +77,19 @@ class Embedding(nn.Module):
else:
self.embed.requires_grad = value

@property
def size(self):
if isinstance(self.embed, TokenEmbedding):
return torch.Size(self.embed._word_vocab, self.embed.embed_size)
else:
return self.embed.weight.size()

class TokenEmbedding(nn.Module):
def __init__(self):
def __init__(self, vocab):
super().__init__()
assert vocab.padding_idx!=None, "Your vocabulary must have padding."
self._word_vocab = vocab
self._word_pad_index = vocab.padding_idx

@property
def requires_grad(self):
@@ -110,8 +116,20 @@ class TokenEmbedding(nn.Module):
def embed_size(self)->int:
return self._embed_size

def get_word_vocab(self):
"""
Return the vocabulary used by this embedding.

:return: Vocabulary
"""
return self._word_vocab

@property
def size(self):
return torch.Size(self.embed._word_vocab, self._embed_size)

class StaticEmbedding(TokenEmbedding):
def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en', requires_grad:bool=False):
"""
Given the name of an embedding resource, extract from it the vectors corresponding to vocab. The resulting Embedding can then be used just like a normal embedding.

@@ -122,7 +140,7 @@ class StaticEmbedding(TokenEmbedding):
:param model_dir_or_name: where the resource is located; a shorthand embedding name may be passed, see xxx for the resources each embedding corresponds to
:param requires_grad: whether gradients are required
"""
super().__init__()
super().__init__(vocab)

# First define which static embeddings are available for download. A dedicated server will probably be needed for this,
PRETRAIN_URL = _get_base_url('static')
@@ -145,11 +163,7 @@ class StaticEmbedding(TokenEmbedding):
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

# load the embedding
if vocab:
embedding = EmbedLoader.load_with_vocab(model_path, vocab=vocab)
else:
embedding, vocab = EmbedLoader.load_without_vocab(model_path)
self._vocab = vocab
embedding = EmbedLoader.load_with_vocab(model_path, vocab=vocab)
embedding = torch.tensor(embedding)
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
padding_idx=vocab.padding_idx,
@@ -158,14 +172,6 @@ class StaticEmbedding(TokenEmbedding):
self._embed_size = self.embedding.weight.size(1)
self.requires_grad = requires_grad

def get_vocab(self):
"""
Return the vocabulary used by this embedding. If the embedding was obtained by passing in a vocab, that same vocab is returned.

:return: Vocabulary
"""
return self._vocab

def forward(self, words):
"""
Takes the indices of words as input.
@@ -177,15 +183,11 @@ class StaticEmbedding(TokenEmbedding):

class DynmicEmbedding(TokenEmbedding):
def __init__(self, vocab:Vocabulary):
assert vocab.padding_idx!=None, "Your vocabulary must have padding."
super().__init__()
self._word_vocab = vocab
super().__init__(vocab)

def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights:bool=True):
"""
Since generating dynamic embeddings is rather slow, the embedding of every sentence can be cached so that the generation
process does not have to run each time. The caching mechanism adds a sentence_index column to the dataset; on every input,
this sentence_index column is passed in and the cached result for that index is returned directly.
Since generating dynamic embeddings is rather slow, the embedding of every sentence can be cached so that the generation process does not have to run each time.

Example::

@@ -202,13 +204,11 @@ class DynmicEmbedding(TokenEmbedding):
try:
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
if dataset.has_field('sentence_index'):
print("Warning: dataset has `sentence_index` already, refresh sometimes will cause chaos.")
except Exception as e:
print(f"Exception happens at {index} dataset.")
raise e

sent_embeds = []
sent_embeds = {}
_move_model_to_device(self, device=device)
device = _get_model_device(self)
pad_index = self._word_vocab.padding_idx
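A tiny illustration (toy keys and values, not from the diff) of why switching sent_embeds from a list indexed by sentence_index to a dict keyed by word ids removes the old warning about refreshing the cache: re-caching a sentence simply overwrites its entry instead of shifting every later index.

cache = {}
cache[(3, 7, 2)] = 'embedding of sentence A'
cache[(5, 1, 4, 9)] = 'embedding of sentence B'
cache[(3, 7, 2)] = 'refreshed embedding of sentence A'  # an idempotent update
print(len(cache))  # 2 -- cached entries for other sentences are unaffected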
@@ -219,45 +219,47 @@ class DynmicEmbedding(TokenEmbedding):
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False)
for batch_x, batch_y in batch:
words = batch_x['words'].to(device)
words_list = words.tolist()
seq_len = words.ne(pad_index).sum(dim=-1)
max_len = words.size(1)
# Since CLS and SEP may be included in some cases, it is safer to count from the end.
seq_len_from_behind =(max_len - words.ne(pad_index).sum(dim=-1)).tolist()
seq_len_from_behind =(max_len - seq_len).tolist()
word_embeds = self(words).detach().cpu().numpy()
for b in range(words.size(0)):
length = seq_len_from_behind[b]
if length==0:
sent_embeds.append(word_embeds[b])
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b]
else:
sent_embeds.append(word_embeds[b, :-length])
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
except Exception as e:
print(f"Exception happens at {index} dataset.")
raise e
print("Finish calculating sentence representations.")
start_idx = 0
for dataset in datasets:
sent_index = list(range(start_idx, start_idx+len(dataset)))
dataset.add_field('sentence_index', sent_index, is_input=True)
self.sent_embeds = np.array(sent_embeds)
self.sent_embeds = sent_embeds
if delete_weights:
self._delete_model_weights()

def _get_sent_reprs(self, sentence_index):
def _get_sent_reprs(self, words):
"""
Get the representations of the sentences; if a cache exists return the cached values, otherwise return None.

:param sentence_index: torch.LongTensor
:param words: torch.LongTensor
:return:
"""
if sentence_index is not None:
if hasattr(self, 'sent_embeds'):
sentence_index_lst = sentence_index.tolist()
_embeds = self.sent_embeds[sentence_index_lst]
max_sent_len = max(map(len, _embeds))
embeds = sentence_index.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
device=sentence_index.device)
for i, embed in enumerate(_embeds):
embeds[i, :len(embed)] = torch.FloatTensor(embed).to(sentence_index.device)
return embeds
if hasattr(self, 'sent_embeds'):
words_list = words.tolist()
seq_len = words.ne(self._word_pad_index).sum(dim=-1)
_embeds = []
for b in range(len(words)):
words_i = tuple(words_list[b][:seq_len[b]])
embed = self.sent_embeds[words_i]
_embeds.append(embed)
max_sent_len = max(map(len, _embeds))
embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
device=words.device)
for i, embed in enumerate(_embeds):
embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
return embeds
return None

@abstractmethod
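A short numpy sketch of the trimming logic above (the shapes are assumptions for illustration): because models such as BERT may add [CLS]/[SEP], the number of returned vectors for a row can exceed the number of word ids, so slicing seq_len vectors from the front could drop real content, while removing max_len - seq_len padding positions from the end is always safe.

import numpy as np

seq_len, max_len, embed_dim = 3, 5, 4          # 3 real tokens, 2 pads in this row
# Suppose the model also returned vectors for [CLS] and [SEP] (include_cls_sep=True),
# so the row has seq_len + 2 content vectors followed by max_len - seq_len padding vectors.
word_embeds_b = np.random.rand(seq_len + 2 + (max_len - seq_len), embed_dim)

length = max_len - seq_len                     # i.e. seq_len_from_behind[b]
trimmed = word_embeds_b[:-length] if length else word_embeds_b
print(trimmed.shape)                           # (5, 4): [CLS] + 3 tokens + [SEP], padding removed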
@@ -304,7 +306,7 @@ class ElmoEmbedding(DynmicEmbedding):
PRETRAIN_URL = _get_base_url('elmo')
# TODO add the models hosted on Baidu Cloud
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz',
'cn': 'elmo_cn.zip'}
'cn': 'elmo_cn-5e9b34e2.tar.gz'}

if model_dir_or_name in PRETRAINED_ELMO_MODEL_DIR:
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
@@ -319,17 +321,16 @@ class ElmoEmbedding(DynmicEmbedding):
self.requires_grad = requires_grad
self._embed_size = len(self.layers)*self.model.config['encoder']['projection_dim']*2

def forward(self, words:torch.LongTensor, sentence_index=None):
def forward(self, words:torch.LongTensor):
"""
Compute the ELMo embedding of words. As described in the ELMo paper, ELMo actually produces 2L+1 layers of output, but to
make the result easier to split apart, the token embedding is repeated once, so that layer=0 is in fact
[token_embedding; token_embedding] while layer=1 is [forward_hiddens; backward_hiddens].

:param words: batch_size x max_len
:param sentence_index: batch_size, useful when the sentence cache has been used.
:return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
"""
outputs = self._get_sent_reprs(sentence_index)
outputs = self._get_sent_reprs(words)
if outputs is not None:
return outputs
outputs = self.model(words)
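To make the layer layout described in the docstring concrete, a sketch with assumed dimensions (projection_dim = 256 here so each layer contributes 512 values, matching the return annotation above; the real value depends on the downloaded model), assuming the selected layers are concatenated in order along the last dimension:

import torch

num_layers, projection_dim = 3, 256
batch_size, max_len = 2, 7
output = torch.randn(batch_size, max_len, num_layers * projection_dim * 2)

# Each selected layer contributes a [left; right] pair: layer 0 is
# [token_embedding; token_embedding], the later layers are [forward_hiddens; backward_hiddens].
per_layer = output.view(batch_size, max_len, num_layers, 2, projection_dim)
token_embedding = per_layer[:, :, 0, 0]   # left half of layer 0
forward_hiddens = per_layer[:, :, 1, 0]   # left half of layer 1
print(token_embedding.shape, forward_hiddens.shape)  # both torch.Size([2, 7, 256])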
@@ -373,7 +374,6 @@ class BertEmbedding(DynmicEmbedding):




:param vocab: Vocabulary
:param model_dir_or_name: the directory containing the model, or the name of the model.
:param layers: which layers go into the final representation; layer indices are separated by ',', and negative numbers index from the last layer
@@ -411,16 +411,15 @@ class BertEmbedding(DynmicEmbedding):
def _delete_model_weights(self):
del self.model

def forward(self, words, sentence_index=None):
def forward(self, words):
"""
Compute the BERT embedding of words. Before computation, [CLS] is inserted at the start of every sentence and [SEP] at the
end, and include_cls_sep determines whether the representations of these two tokens are removed.

:param words: batch_size x max_len
:param sentence_index: batch_size, used when the sentence representations have been cached.
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
outputs = self._get_sent_reprs(sentence_index)
outputs = self._get_sent_reprs(words)
if outputs is not None:
return outputs
outputs = self.model(words)
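A quick shape sketch of the include_cls_sep behaviour described above (dimensions made up; padding is ignored so [SEP] sits in the final position of this toy tensor):

import torch

batch_size, max_len, hidden = 2, 7, 768
with_specials = torch.randn(batch_size, max_len + 2, hidden)  # [CLS] w1 .. w7 [SEP]

include_cls_sep = False
output = with_specials if include_cls_sep else with_specials[:, 1:-1, :]
print(output.shape)  # torch.Size([2, 7, 768]) once the two special tokens are stripped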
@@ -481,7 +480,7 @@ class CNNCharEmbedding(TokenEmbedding):
"""
def __init__(self, vocab:Vocabulary, embed_size:int=50, char_emb_size:int=50, filter_nums:List[int]=(40, 30, 20),
kernel_sizes:List[int]=(5, 3, 1), pool_method='max', activation='relu', min_char_freq:int=2):
super().__init__()
super().__init__(vocab)

for kernel in kernel_sizes:
assert kernel%2==1, "Only odd kernel is allowed."
@@ -598,7 +597,7 @@ class LSTMCharEmbedding(TokenEmbedding):
"""
def __init__(self, vocab:Vocabulary, embed_size:int=50, char_emb_size:int=50, hidden_size=50,
pool_method='max', activation='relu', min_char_freq:int=2, bidirectional=True):
super().__init__()
super().__init__(vocab)

assert hidden_size % 2 == 0, "Only even kernel is allowed."

@@ -705,7 +704,14 @@ class StackEmbedding(TokenEmbedding):

"""
def __init__(self, embeds:List[TokenEmbedding]):
super().__init__()
vocabs = []
for embed in embeds:
vocabs.append(embed.get_word_vocab())
_vocab = vocabs[0]
for vocab in vocabs[1:]:
assert vocab==_vocab, "All embeddings should use the same word vocabulary."

super().__init__(_vocab)
assert isinstance(embeds, list)
for embed in embeds:
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
@@ -749,17 +755,15 @@ class StackEmbedding(TokenEmbedding):
for embed in self.embeds():
embed.requires_grad = value

def forward(self, words, sentence_index=None):
def forward(self, words):
"""
Compute the results of the individual embeddings and concatenate them in order.

:param words: batch_size x max_len
:param sentence_index: batch_size, only used when one of the contained embeddings has a sentence cache
:return: the returned shape depends on which embeddings make up this StackEmbedding
"""
outputs = []
for embed in self.embeds:
inputs = _build_args(embed.forward, words=words, sentence_index=sentence_index)
outputs.append(embed(**inputs))
outputs.append(embed(words))
return torch.cat(outputs, dim=-1)
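To make the StackEmbedding contract concrete, a toy sketch (the member dimensions are made up) of what the final torch.cat(outputs, dim=-1) produces; the stacked embed_size is simply the sum of the member embed sizes:

import torch

batch_size, max_len = 2, 7
word_vectors = torch.randn(batch_size, max_len, 100)   # e.g. from a StaticEmbedding
char_vectors = torch.randn(batch_size, max_len, 50)    # e.g. from a CNNCharEmbedding

stacked = torch.cat([word_vectors, char_vectors], dim=-1)
print(stacked.shape)  # torch.Size([2, 7, 150])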

