-
- r"""undocumented
- 这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你
- 有用,也请引用一下他们。
- """
-
- __all__ = [
- 'RobertaModel'
- ]
-
- import torch
- import torch.nn as nn
-
- from .bert import BertEmbeddings, BertModel, BertConfig
- from ...io.file_utils import _get_file_name_base_on_postfix, _get_roberta_dir
- from ...core import logger
-
- PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = {
- "roberta-base": 512,
- "roberta-large": 512,
- "roberta-large-mnli": 512,
- "distilroberta-base": 512,
- "roberta-base-openai-detector": 512,
- "roberta-large-openai-detector": 512,
- }
-
-
- class RobertaEmbeddings(BertEmbeddings):
- """
- Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
- """
-
- def __init__(self, config):
- super().__init__(config)
- self.padding_idx = 1
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
- self.position_embeddings = nn.Embedding(
- config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
- )
-
- def forward(self, input_ids, token_type_ids, words_embeddings=None):
- position_ids = self.create_position_ids_from_input_ids(input_ids)
-
- return super().forward(
- input_ids, token_type_ids=token_type_ids, position_ids=position_ids, words_embeddings=words_embeddings
- )
-
- def create_position_ids_from_input_ids(self, x):
- """ Replace non-padding symbols with their position numbers. Position numbers begin at
- padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
- `utils.make_positions`.
-
- :param torch.Tensor x: input ids, where padding positions equal ``self.padding_idx``
- :return torch.Tensor: position ids with the same shape as ``x``
- """
- mask = x.ne(self.padding_idx).long()
- incremental_indices = torch.cumsum(mask, dim=1) * mask
- return incremental_indices + self.padding_idx
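-
- # A worked illustration (not part of the original module): with padding_idx = 1
- # and input ids x = [[5, 6, 7, 1, 1]],
- #   mask                       = [[1, 1, 1, 0, 0]]
- #   cumsum(mask, dim=1) * mask = [[1, 2, 3, 0, 0]]
- #   returned position ids      = [[2, 3, 4, 1, 1]]
- # i.e. real tokens are numbered from padding_idx + 1 onward, while every padding
- # token is assigned padding_idx itself, matching the docstring above.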
-
-
- class RobertaModel(BertModel):
- r"""
- undocumented
- """
-
- def __init__(self, config):
- super().__init__(config)
-
- self.embeddings = RobertaEmbeddings(config)
- self.apply(self.init_bert_weights)
-
- @classmethod
- def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
- state_dict = kwargs.pop('state_dict', None)
- kwargs.pop('cache_dir', None)
- kwargs.pop('from_tf', None)
-
- # get model dir from name or dir
- pretrained_model_dir = _get_roberta_dir(model_dir_or_name)
-
- # Load config
- config_file = _get_file_name_base_on_postfix(pretrained_model_dir, 'config.json')
- config = BertConfig.from_json_file(config_file)
-
- # Load model
- if state_dict is None:
- weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
- state_dict = torch.load(weights_path, map_location='cpu')
- else:
- logger.error('Loading parameters through the `state_dict` argument is not supported; pass a model directory or name instead.')
- raise RuntimeError('Loading parameters through the `state_dict` argument is not supported.')
-
- # Instantiate model.
- model = cls(config, *inputs, **kwargs)
-
- missing_keys = []
- unexpected_keys = []
- error_msgs = []
-
- # Convert old-format parameter names in a PyTorch state_dict to the new format if needed
- old_keys = []
- new_keys = []
- for key in state_dict.keys():
- new_key = None
- if "gamma" in key:
- new_key = key.replace("gamma", "weight")
- if "beta" in key:
- new_key = key.replace("beta", "bias")
- if new_key:
- old_keys.append(key)
- new_keys.append(new_key)
- for old_key, new_key in zip(old_keys, new_keys):
- state_dict[new_key] = state_dict.pop(old_key)
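- # e.g. an old-format key such as 'encoder.layer.0.attention.output.LayerNorm.gamma'
- # (an illustrative name) is stored back as '...LayerNorm.weight', and the
- # corresponding '...LayerNorm.beta' key as '...LayerNorm.bias'.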
-
- # copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, "_metadata", None)
- state_dict = state_dict.copy()
- if metadata is not None:
- state_dict._metadata = metadata
-
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
- # so we need to apply the function recursively.
- def load(module: nn.Module, prefix=""):
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
- module._load_from_state_dict(
- state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
- )
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- # Make sure we are able to load base models as well as derived models (with heads)
- start_prefix = ""
- model_to_load = model
- if not hasattr(model, 'roberta') and any(
- s.startswith('roberta') for s in state_dict.keys()
- ):
- start_prefix = 'roberta.'
- if hasattr(model, 'roberta') and not any(
- s.startswith('roberta') for s in state_dict.keys()
- ):
- model_to_load = getattr(model, 'roberta')
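- # e.g. a checkpoint exported from a model with a task head typically stores keys such as
- # 'roberta.embeddings.word_embeddings.weight' (illustrative names), while this bare
- # RobertaModel expects 'embeddings.word_embeddings.weight'; the prefix handling above
- # covers loading in either direction.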
-
- load(model_to_load, prefix=start_prefix)
-
- if model.__class__.__name__ != model_to_load.__class__.__name__:
- base_model_state_dict = model_to_load.state_dict().keys()
- head_model_state_dict_without_base_prefix = [
- key.split('roberta.')[-1] for key in model.state_dict().keys()
- ]
-
- missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
-
- if len(missing_keys) > 0:
- logger.info(
- "Weights of {} not initialized from pretrained model: {}".format(
- model.__class__.__name__, missing_keys
- )
- )
- if len(unexpected_keys) > 0:
- logger.info(
- "Weights from pretrained model not used in {}: {}".format(
- model.__class__.__name__, unexpected_keys
- )
- )
- if len(error_msgs) > 0:
- raise RuntimeError(
- "Error(s) in loading state_dict for {}:\n\t{}".format(
- model.__class__.__name__, "\n\t".join(error_msgs)
- )
- )
-
- # Set the model in evaluation mode to deactivate dropout modules by default
- model.eval()
-
- logger.info(f"Load pre-trained RoBERTa parameters from file {weights_path}.")
-
- return model
-
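-
- # Usage sketch (not from the original module; it assumes `model_dir_or_name` resolves via
- # `_get_roberta_dir`, e.g. to a local directory containing a config.json and a .bin weights
- # file, and that the inherited BertModel.forward returns encoded layers plus a pooled
- # output in the pytorch-pretrained-BERT style):
- #
- #     model = RobertaModel.from_pretrained('/path/to/roberta-base-dir')
- #     input_ids = torch.LongTensor([[5, 6, 7, 1, 1]])   # toy ids, 1 = padding
- #     token_type_ids = torch.zeros_like(input_ids)
- #     encoded_layers, pooled_output = model(input_ids, token_type_ids=token_type_ids)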