From 030e0aa3ee31b1489d3348f42998d4aae14a5f56 Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Sat, 26 Dec 2020 09:29:17 +0800
Subject: [PATCH] update some function about bert and roberta

---
 fastNLP/embeddings/bert_embedding.py    | 42 ++++++++++++++-------
 fastNLP/embeddings/roberta_embedding.py | 50 ++++++++++++++++++-------
 fastNLP/modules/encoder/bert.py         |  8 ++--
 3 files changed, 69 insertions(+), 31 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 29b17c65..6434cc0d 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -110,11 +110,12 @@ class BertEmbedding(ContextualEmbedding):
         if '[CLS]' in vocab:
             self._word_cls_index = vocab['[CLS]']
 
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
         self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)
+                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate,
+                                    **kwargs)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -367,32 +368,44 @@ class BertWordPieceEncoder(nn.Module):
 
 class _BertWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
 
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
 
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
 
         self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = BertModel.from_pretrained(model_dir_or_name,
                                                  neg_num_output_layer=neg_num_output_layer,
-                                                 pos_num_output_layer=pos_num_output_layer)
+                                                 pos_num_output_layer=pos_num_output_layer,
+                                                 **kwargs)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'Bert Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
 
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
@@ -417,7 +430,7 @@ class _BertWordModel(nn.Module):
                 word = '[PAD]'
             elif index == vocab.unknown_idx:
                 word = '[UNK]'
-            elif vocab.word_count[word] < min_freq:
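A minimal usage sketch of the new behaviour (illustrative, not part of the patch; the toy vocabulary and the 'en-base-uncased' model name are placeholders): layers='all' now defers layer selection until the encoder has been loaded, and extra keyword arguments such as min_freq are popped in BertEmbedding and forwarded through **kwargs to _BertWordModel.

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    # toy vocabulary; any fastNLP Vocabulary works here
    vocab = Vocabulary()
    vocab.add_word_lst("this is a small example .".split())

    # 'all' selects the embedding output (layer-0) plus every encoder layer
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='all', min_freq=2)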
diff --git a/fastNLP/embeddings/roberta_embedding.py b/fastNLP/embeddings/roberta_embedding.py
--- a/fastNLP/embeddings/roberta_embedding.py
+++ b/fastNLP/embeddings/roberta_embedding.py
         if '<s>' in vocab:
             self._word_cls_index = vocab['<s>']
 
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
         self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                        pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq)
+                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+                                       **kwargs)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -193,33 +194,45 @@ class RobertaEmbedding(ContextualEmbedding):
 
 class _RobertaWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
 
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
 
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
 
         self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
                                                     neg_num_output_layer=neg_num_output_layer,
-                                                    pos_num_output_layer=pos_num_output_layer)
+                                                    pos_num_output_layer=pos_num_output_layer,
+                                                    **kwargs)
         # RobertaEmbedding sets padding_idx to 1 and computes positions in a rather peculiar way, hence the -2
         self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
         # check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'RoBERTa Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
 
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
@@ -241,7 +254,7 @@ class _RobertaWordModel(nn.Module):
                 word = '<pad>'
             elif index == vocab.unknown_idx:
                 word = '<unk>'
-            elif vocab.word_count[word] < min_freq:

         if max_word_piece_length + 2 > self._max_position_embeddings:
             if self.auto_truncate:
                 word_pieces_lengths = word_pieces_lengths.masked_fill(
-                    word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
+                    word_pieces_lengths + 2 > self._max_position_embeddings,
+                    self._max_position_embeddings - 2)
             else:
                 raise RuntimeError(
                     "After split words into word pieces, the lengths of word pieces are longer than the "
@@ -290,6 +305,7 @@ class _RobertaWordModel(nn.Module):
             word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
             word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
+        # add
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
@@ -362,6 +378,12 @@ class _RobertaWordModel(nn.Module):
         return outputs
 
     def save(self, folder):
+        """
+        Save pytorch_model.bin, config.json and vocab.txt into the given folder.
+
+        :param str folder:
+        :return:
+        """
         self.tokenizer.save_pretrained(folder)
         self.encoder.save_pretrained(folder)
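A parallel sketch for the RoBERTa side (illustrative only; 'en' stands in for any RoBERTa model name fastNLP accepts): with layers='all', _RobertaWordModel resolves the layer list only after the encoder is loaded, so the reported embedding size becomes (number of encoder layers + 1) * hidden_size, layer-0 being the embedding output.

    from fastNLP import Vocabulary
    from fastNLP.embeddings import RobertaEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("another tiny example .".split())

    embed = RobertaEmbedding(vocab, model_dir_or_name='en', layers='all')
    # embed_size == len(model.layers) * hidden_size
    print(embed.embed_size)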
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index f304073d..55e79d63 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -184,21 +184,23 @@ class DistilBertEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids):
+    def forward(self, input_ids, token_type_ids, position_ids=None):
         r"""
         Parameters
         ----------
         input_ids: torch.tensor(bs, max_seq_length)
             The token ids to embed.
         token_type_ids: no used.
+        position_ids: optional; if None, position ids are built from the sequence length.
         Outputs
         -------
         embeddings: torch.tensor(bs, max_seq_length, dim)
             The embedded tokens (plus position embeddings, no token_type embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
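The change above only builds default position ids when the caller does not pass any. A standalone sketch of that default-or-override pattern (build_position_ids is a hypothetical helper for illustration, not part of fastNLP):

    import torch

    def build_position_ids(input_ids, position_ids=None):
        # mirror the patched forward(): create default position ids only when none are given
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
        return position_ids

    input_ids = torch.randint(0, 100, (2, 6))
    assert build_position_ids(input_ids).shape == input_ids.shape   # default path
    custom = torch.zeros_like(input_ids)
    assert build_position_ids(input_ids, custom) is custom          # caller-supplied ids pass through unchanged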