@@ -49,10 +49,13 @@ class BertEmbedding(ContextualEmbedding):
     :param bool pooled_cls: Whether to map the returned [CLS] through the pretrained BertPool; only effective when include_cls_sep is set. If the downstream task only uses [CLS] for prediction,
         this is usually set to True.
     :param bool requires_grad: Whether gradients are required in order to update Bert's weights.
+    :param bool auto_truncate: If the word pieces a sentence splits into exceed bert's maximum allowed length (usually 512), automatically cut off everything beyond the first 510
+        word pieces and set the 512th word piece to [SEP]. The encode results of the part beyond the limit are all set to zero. Usually only tasks that classify
+        from [CLS] alone should set auto_truncate to True.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
                  pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False,
-                 pooled_cls=True, requires_grad: bool=False):
+                 pooled_cls=True, requires_grad: bool=False, auto_truncate:bool=False):
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # check that model_dir_or_name exists, downloading the model if necessary
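As a rough usage sketch (not part of the patch; the import paths and the toy vocabulary are assumptions), the new flag is simply passed through the BertEmbedding constructor:

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

# hypothetical toy vocabulary; any Vocabulary built over the task's words works the same way
vocab = Vocabulary()
vocab.add_word_lst("a very long sentence that splits into many word pieces".split())
# with auto_truncate=True, inputs whose word pieces exceed bert's limit are clipped instead of raising a RuntimeError
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', pool_method='first',
                      include_cls_sep=False, pooled_cls=True, auto_truncate=True)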

@@ -69,7 +72,7 @@ class BertEmbedding(ContextualEmbedding):
 
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls)
+                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
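For reference, the embedding size is the number of requested layers times the encoder's hidden size: with layers='-1,-2,-3,-4' and a base-sized bert (hidden_size=768), the output dimension is 4*768 = 3072.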

@@ -202,11 +205,12 @@ class BertWordPieceEncoder(nn.Module):
 
 class _WordBertModel(nn.Module):
     def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
-                 include_cls_sep:bool=False, pooled_cls:bool=False):
+                 include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False):
         super().__init__()
 
         self.tokenzier = BertTokenizer.from_pretrained(model_dir)
         self.encoder = BertModel.from_pretrained(model_dir)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is valid
         encoder_layer_number = len(self.encoder.encoder.layer)
         self.layers = list(map(int, layers.split(',')))

@@ -222,6 +226,7 @@ class _WordBertModel(nn.Module):
         self.pool_method = pool_method
         self.include_cls_sep = include_cls_sep
         self.pooled_cls = pooled_cls
+        self.auto_truncate = auto_truncate
 
         # compute the word pieces for every word in the vocab; [CLS] and [SEP] need to be taken into account as well
         print("Start generating word pieces for words.")

@@ -290,6 +295,17 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
+        real_max_word_piece_length = max_word_piece_length  # the word piece length before any truncation
+        if max_word_piece_length+2>self._max_position_embeddings:
+            if self.auto_truncate:
+                word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings,
+                                                                      self._max_position_embeddings-2)
+                max_word_piece_length = self._max_position_embeddings-2
+            else:
+                raise RuntimeError("After splitting words into word pieces, the lengths of word pieces are longer than the "
+                                   f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
+
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
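The clamping above can be illustrated in isolation with plain torch; the limit of 8 below is a made-up stand-in for self._max_position_embeddings, which is normally 512:

import torch

max_position_embeddings = 8  # hypothetical small limit; real bert configs use 512
word_pieces_lengths = torch.tensor([3, 7, 12])  # per-sentence word piece counts
# sentences whose pieces plus [CLS]/[SEP] would overflow are clamped to limit - 2
clamped = word_pieces_lengths.masked_fill(word_pieces_lengths + 2 > max_position_embeddings,
                                          max_position_embeddings - 2)
print(clamped)  # tensor([3, 6, 6])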

@@ -300,6 +316,8 @@ class _WordBertModel(nn.Module):
         word_indexes = words.tolist()
         for i in range(batch_size):
             word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]]))
+            if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
+                word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :len(word_pieces_i)+2].fill_(1)
         # TODO truncate the part that exceeds the maximum length.
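This per-sentence guard mirrors the batch-level clamp above: word_pieces_i is cut to at most self._max_position_embeddings-2 pieces before being written into the tensor, so the attention mask then covers len(word_pieces_i)+2 positions, i.e. the (possibly truncated) pieces plus [CLS] and [SEP].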

@@ -321,6 +339,11 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
         for l_index, l in enumerate(self.layers):
             output_layer = bert_outputs[l]
+            if real_max_word_piece_length > max_word_piece_length:  # the batch was actually truncated
+                paddings = output_layer.new_zeros(batch_size,
+                                                  real_max_word_piece_length-max_word_piece_length,
+                                                  output_layer.size(2))
+                output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
             # collapse from word piece representations back to word representations
             truncate_output_layer = output_layer[:, 1:-1]  # drop [CLS] and [SEP], batch_size x len x hidden_size
             outputs_seq_len = seq_len + s_shift
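Zero-padding the encoder output back up to real_max_word_piece_length keeps the later gathering via batch_word_pieces_cum_length valid for truncated sentences: word pieces that fell beyond the limit simply pick up zero vectors, matching the docstring's statement that the encode results of the truncated part are all zeros.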