@@ -110,11 +110,12 @@ class BertEmbedding(ContextualEmbedding):
         if '[CLS]' in vocab:
             self._word_cls_index = vocab['[CLS]']
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
         self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)
+                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate,
+                                    **kwargs)
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
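Switching `kwargs.get` to `kwargs.pop` is what makes the new `**kwargs` forwarding safe: with `get`, `min_freq` would stay in `kwargs` and reach `_BertWordModel` twice, once explicitly and once via `**kwargs`. A minimal sketch of the failure mode (helper names are hypothetical):

```python
# Minimal sketch with hypothetical helpers: why `pop` is required once the
# remaining kwargs are forwarded onward.
def inner(min_freq=2, **kwargs):
    return min_freq

def outer_get(**kwargs):
    min_freq = kwargs.get('min_freq', 1)       # 'min_freq' stays in kwargs
    return inner(min_freq=min_freq, **kwargs)  # TypeError: duplicate argument

def outer_pop(**kwargs):
    min_freq = kwargs.pop('min_freq', 1)       # 'min_freq' removed from kwargs
    return inner(min_freq=min_freq, **kwargs)  # OK

print(outer_pop(min_freq=3))  # 3
# outer_get(min_freq=3)       # would raise TypeError
```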
@@ -367,32 +368,44 @@ class BertWordPieceEncoder(nn.Module):
 class _BertWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
         self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = BertModel.from_pretrained(model_dir_or_name,
                                                  neg_num_output_layer=neg_num_output_layer,
-                                                 pos_num_output_layer=pos_num_output_layer)
+                                                 pos_num_output_layer=pos_num_output_layer,
+                                                 **kwargs)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'Bert Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
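With this hunk, `layers='all'` defers layer selection until the checkpoint is loaded, then expands to every layer including the embedding output (layer-0). A hedged usage sketch, assuming fastNLP's public `BertEmbedding` API and its 'en-base-uncased' model alias:

```python
# Usage sketch; the model alias and the 12-layer/768-dim figures assume a
# standard bert-base checkpoint (instantiating downloads the weights).
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary().add_word_lst("this is a test".split())
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='all')
# A 12-layer BERT yields 13 outputs (embedding layer + 12 encoder layers),
# concatenated along the feature dimension.
print(embed.embedding_dim)  # 13 * 768 = 9984
```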
@@ -417,7 +430,7 @@ class _BertWordModel(nn.Module):
                 word = '[PAD]'
             elif index == vocab.unknown_idx:
                 word = '[UNK]'
-            elif vocab.word_count[word]<min_freq:
+            elif vocab.word_count[word] < min_freq:
                 word = '[UNK]'
             word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
             word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
@@ -481,14 +494,15 @@ class _BertWordModel(nn.Module):
         token_type_ids = torch.zeros_like(word_pieces)
         # 2. get the hidden states, then pool them back to word level according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
-        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
+                                                attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size
         if self.include_cls_sep:
             s_shift = 1
             outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
-                                                 bert_outputs[-1].size(-1))
+                                                 bert_outputs[-1].size(-1))
         else:
             s_shift = 0
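The pre-allocated `outputs` tensor fixes the shape the caller gets back: one slice per selected layer, with two extra positions when `include_cls_sep=True`, and `s_shift` offsetting where word `i` is written. A small shape sketch with illustrative sizes:

```python
# Shape sketch with made-up sizes: 2 selected layers, hidden size 768.
import torch

num_layers, batch_size, max_word_len, hidden_size = 2, 4, 10, 768
include_cls_sep = True
extra = 2 if include_cls_sep else 0    # room for [CLS] and [SEP]
s_shift = 1 if include_cls_sep else 0  # word i lands at position i + s_shift
outputs = torch.zeros(num_layers, batch_size, max_word_len + extra, hidden_size)
print(outputs.shape)  # torch.Size([2, 4, 12, 768])
```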
@@ -93,12 +93,13 @@ class RobertaEmbedding(ContextualEmbedding):
         if '<s>' in vocab:
             self._word_cls_index = vocab['<s>']
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
         self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                        pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq)
+                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+                                       **kwargs)
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -193,33 +194,45 @@ class RobertaEmbedding(ContextualEmbedding):
 class _RobertaWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
         self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
                                                     neg_num_output_layer=neg_num_output_layer,
-                                                    pos_num_output_layer=pos_num_output_layer)
+                                                    pos_num_output_layer=pos_num_output_layer,
+                                                    **kwargs)
         # RobertaEmbedding sets padding_idx to 1 and uses a rather peculiar position computation, hence the -2
         self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
         # check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'RoBERTa Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
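The `- 2` encodes RoBERTa's position-id scheme: with `padding_idx = 1`, real tokens are numbered from `padding_idx + 1` upward, so two of the `max_position_embeddings` slots are never usable for actual word pieces. A sketch of the arithmetic, assuming the standard roberta-base configuration:

```python
# Arithmetic sketch assuming roberta-base's standard config values.
max_position_embeddings = 514  # roberta-base config
padding_idx = 1                # ids 0 and 1 are reserved
# Non-padding tokens get positions padding_idx + 1, padding_idx + 2, ...
usable = max_position_embeddings - padding_idx - 1
print(usable)  # 512 word pieces fit, <s> and </s> included
```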
@@ -241,7 +254,7 @@ class _RobertaWordModel(nn.Module):
                 word = '<pad>'
             elif index == vocab.unknown_idx:
                 word = '<unk>'
-            elif vocab.word_count[word]<min_freq:
+            elif vocab.word_count[word] < min_freq:
                 word = '<unk>'
             word_pieces = self.tokenizer.tokenize(word)
             word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
@@ -265,13 +278,15 @@ class _RobertaWordModel(nn.Module):
         batch_size, max_word_len = words.size()
         word_mask = words.ne(self._word_pad_index)  # positions where the mask is 1 hold a word
         seq_len = word_mask.sum(dim=-1)
-        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), 0)  # batch_size x max_len
+        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
+                                                                               0)  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
         max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word-piece length (padding included)
         if max_word_piece_length + 2 > self._max_position_embeddings:
             if self.auto_truncate:
                 word_pieces_lengths = word_pieces_lengths.masked_fill(
-                    word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
+                    word_pieces_lengths + 2 > self._max_position_embeddings,
+                    self._max_position_embeddings - 2)
             else:
                 raise RuntimeError(
                     "After split words into word pieces, the lengths of word pieces are longer than the "
@@ -290,6 +305,7 @@ class _RobertaWordModel(nn.Module):
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
             word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
+        # add <s> and </s>
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
@@ -362,6 +378,12 @@ class _RobertaWordModel(nn.Module):
         return outputs

     def save(self, folder):
+        """
+        Save pytorch_model.bin, config.json and vocab.txt to the given folder.
+
+        :param str folder:
+        :return:
+        """
         self.tokenizer.save_pretrained(folder)
         self.encoder.save_pretrained(folder)
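Since `save` just delegates to the tokenizer's and encoder's `save_pretrained`, the saved folder can be fed straight back into `from_pretrained`-style loading. A hedged round-trip sketch; the module path, the 'en' model alias, and the checkpoint path are assumptions, and instantiating the model downloads weights:

```python
# Hedged round-trip sketch; names noted above are assumptions.
from fastNLP import Vocabulary
from fastNLP.embeddings.roberta_embedding import _RobertaWordModel

vocab = Vocabulary().add_word_lst("hello world".split())
model = _RobertaWordModel(model_dir_or_name='en', vocab=vocab, layers='-1')
model.save('./roberta_ckpt')  # writes tokenizer files + pytorch_model.bin/config.json
# The saved folder now loads like any local checkpoint:
reloaded = _RobertaWordModel(model_dir_or_name='./roberta_ckpt', vocab=vocab, layers='-1')
```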
@@ -184,21 +184,23 @@ class DistilBertEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

-    def forward(self, input_ids, token_type_ids):
+    def forward(self, input_ids, token_type_ids, position_ids=None):
         r"""
         Parameters
         ----------
         input_ids: torch.tensor(bs, max_seq_length)
             The token ids to embed.
         token_type_ids: not used.
+        position_ids: optional position ids; computed from the sequence length when None.

         Outputs
         -------
         embeddings: torch.tensor(bs, max_seq_length, dim)
             The embedded tokens (plus position embeddings, no token_type embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
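Making `position_ids` optional keeps the old default (positions `0..seq_len-1`) while letting callers supply their own, e.g. to continue numbering across document chunks. A minimal runnable sketch; `TinyEmbeddings` is a stand-in with the same forward signature, not the real `DistilBertEmbeddings`:

```python
import torch
import torch.nn as nn

# Stand-in mirroring the patched forward signature (illustrative sizes).
class TinyEmbeddings(nn.Module):
    def __init__(self, vocab_size=100, max_pos=512, dim=16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)
        self.position_embeddings = nn.Embedding(max_pos, dim)

    def forward(self, input_ids, token_type_ids, position_ids=None):
        if position_ids is None:  # old behaviour: positions 0..seq_len-1
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        return self.word_embeddings(input_ids) + self.position_embeddings(position_ids)

emb = TinyEmbeddings()
input_ids = torch.randint(0, 100, (2, 5))
token_type_ids = torch.zeros_like(input_ids)
default_out = emb(input_ids, token_type_ids)  # positions generated internally
shifted = torch.arange(5, 10).unsqueeze(0).expand_as(input_ids)
offset_out = emb(input_ids, token_type_ids, position_ids=shifted)  # caller-supplied
```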