- # coding=utf-8
- # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ PyTorch CPT model. modified from transformers==4.4.1"""
- import math
- import random
- from typing import Optional, Tuple
-
- from fastNLP.transformers.torch.activations import ACT2FN
- from fastNLP.transformers.torch.file_utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- replace_return_docstrings,
- )
- from fastNLP.transformers.torch.modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- Seq2SeqQuestionAnsweringModelOutput,
- Seq2SeqSequenceClassifierOutput,
- )
- from fastNLP.transformers.torch.modeling_utils import PreTrainedModel
- from ..bart import BartConfig as CPTConfig
- from ..bert import BertModel, BertConfig
- from fastNLP.core.log import logger
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
-
- if _NEED_IMPORT_TORCH:
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss, LayerNorm, Module, Embedding
- else:
- from fastNLP.core.utils.dummy_class import (
- DummyClass as Module,
- DummyClass as Embedding
- )
-
- __all__ = [
- "CPT_PRETRAINED_MODEL_ARCHIVE_LIST",
- "CPTForConditionalGeneration",
- "CPTForSequenceClassification",
- "CPTForMaskedLM",
- "CPTForQuestionAnswering",
- "CPTModel",
- "CPTPretrainedModel",
- ]
-
- _CHECKPOINT_FOR_DOC = "fnlp/cpt-large"
- _CONFIG_FOR_DOC = "CPTConfig"
- _TOKENIZER_FOR_DOC = "CPTTokenizer"
-
-
- CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
- "fnlp/cpt-large",
- ]
-
-
- def shift_tokens_right(input_ids: "torch.Tensor", pad_token_id: int, decoder_start_token_id: int):
- """
- Shift input ids one token to the right.
- """
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
- shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
- shifted_input_ids[:, 0] = decoder_start_token_id
-
- assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
- # replace possible -100 values in labels by `pad_token_id`
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-
- return shifted_input_ids
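- # Worked example (illustrative token ids): with decoder_start_token_id=2 and pad_token_id=0,
- # labels [[5, 6, -100, -100]] are shifted to [[2, 5, 6, -100]], and the surviving -100 is then
- # replaced by the pad id, giving [[2, 5, 6, 0]].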
-
-
- def _make_causal_mask(input_ids_shape: "torch.Size", dtype: "torch.dtype", past_key_values_length: int = 0):
- """
- Make causal mask used for uni-directional (decoder) self-attention.
- """
- bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
- mask_cond = torch.arange(mask.size(-1))
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
- mask = mask.to(dtype)
-
- if past_key_values_length > 0:
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
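- # Illustrative result: for tgt_len=3 and past_key_values_length=0, the (3, 3) mask rows are
- # [0, -inf, -inf], [0, 0, -inf], [0, 0, 0]; the mask is then broadcast to (bsz, 1, 3, 3), so each
- # query position can only attend to itself and earlier positions.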
-
-
- def _expand_mask(mask: "torch.Tensor", dtype: "torch.dtype", tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
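- # Illustrative result: a padding mask [[1, 1, 0]] expands to shape (1, 1, tgt_len, 3) holding 0
- # for the two real tokens and the dtype minimum (an effectively -inf additive bias) for the
- # padded position, so masked keys receive ~zero attention after the softmax.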
-
- def attention_mask_func(attention_scores, attention_mask):
- return attention_scores + attention_mask
-
- def init_method(std):
- def init_(tensor):
- return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
- return init_
-
- class CPTLearnedPositionalEmbedding(Embedding):
- """
- This module learns positional embeddings up to a fixed maximum size.
- """
-
- def __init__(self, num_embeddings: int, embedding_dim: int):
- # CPT is set up so that if padding_idx is specified then offset the embedding ids by 2
- # and adjust num_embeddings appropriately. Other models don't have this hack
- self.offset = 2
- super().__init__(num_embeddings + self.offset, embedding_dim)
-
- def forward(self, input_ids_shape: "torch.Size", past_key_values_length: int = 0):
- """`input_ids_shape` is expected to be [bsz x seqlen]."""
- bsz, seq_len = input_ids_shape[:2]
- positions = torch.arange(
- past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
- )
- return super().forward(positions + self.offset)
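- # Example (illustrative): with past_key_values_length=0 and seq_len=4, positions 0..3 are looked
- # up at table indices 2..5; the first two rows of the table are reserved by the offset of 2
- # inherited from the original fairseq/BART embedding layout.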
-
-
- class CPTAttention(Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(
- self,
- embed_dim: int,
- num_heads: int,
- dropout: float = 0.0,
- is_decoder: bool = False,
- bias: bool = True,
- ):
- super().__init__()
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.dropout = dropout
- self.head_dim = embed_dim // num_heads
- assert (
- self.head_dim * num_heads == self.embed_dim
- ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
- self.scaling = self.head_dim ** -0.5
- self.is_decoder = is_decoder
-
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
-
- def _shape(self, tensor: "torch.Tensor", seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
- def forward(
- self,
- hidden_states: "torch.Tensor",
- key_value_states: Optional["torch.Tensor"] = None,
- past_key_value: Optional[Tuple["torch.Tensor"]] = None,
- attention_mask: Optional["torch.Tensor"] = None,
- layer_head_mask: Optional["torch.Tensor"] = None,
- output_attentions: bool = False,
- ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
- """Input shape: Batch x Time x Channel"""
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
- bsz, tgt_len, embed_dim = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- if is_cross_attention and past_key_value is not None:
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.view(*proj_shape)
- value_states = value_states.view(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- assert attn_weights.size() == (
- bsz * self.num_heads,
- tgt_len,
- src_len,
- ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
-
- if attention_mask is not None:
- assert attention_mask.size() == (
- bsz,
- 1,
- tgt_len,
- src_len,
- ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = F.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_heads,
- ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights has to be reshaped
- # twice and reused in the following computation
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- # with mpu.get_cuda_rng_tracker().fork():
- attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- assert attn_output.size() == (
- bsz * self.num_heads,
- tgt_len,
- self.head_dim,
- ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
-
- attn_output = (
- attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- .transpose(1, 2)
- .reshape(bsz, tgt_len, embed_dim)
- )
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
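- # Shape summary for the forward pass above: hidden_states (bsz, tgt_len, embed_dim) is projected
- # and reshaped to (bsz * num_heads, tgt_len, head_dim); attn_weights are (bsz * num_heads,
- # tgt_len, src_len), where src_len includes any cached past key/value length; the context is
- # reshaped back to (bsz, tgt_len, embed_dim) and passed through out_proj.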
-
- class CPTDecoderLayer(Module):
- def __init__(self, config: CPTConfig):
- super().__init__()
- self.embed_dim = config.d_model
-
- self.self_attn = CPTAttention(
- embed_dim=self.embed_dim,
- num_heads=config.decoder_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=True,
- )
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
-
- self.self_attn_layer_norm = LayerNorm(self.embed_dim)
- self.encoder_attn = CPTAttention(
- self.embed_dim,
- config.decoder_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=True,
- )
- self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = LayerNorm(self.embed_dim)
-
- def forward(
- self,
- hidden_states: "torch.Tensor",
- attention_mask: Optional["torch.Tensor"] = None,
- encoder_hidden_states: Optional["torch.Tensor"] = None,
- encoder_attention_mask: Optional["torch.Tensor"] = None,
- layer_head_mask: Optional["torch.Tensor"] = None,
- encoder_layer_head_mask: Optional["torch.Tensor"] = None,
- past_key_value: Optional[Tuple["torch.Tensor"]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = True,
- ):
- """
- Args:
- hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (:obj:`torch.FloatTensor`): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
- encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
- `(config.decoder_attention_heads,)`.
- encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
- size `(config.decoder_attention_heads,)`.
- past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
- output_attentions (:obj:`bool`, `optional`):
- Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
- returned tensors for more detail.
- """
- residual = hidden_states
-
- # Self Attention
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
- # add present self-attn cache to positions 1,2 of present_key_value tuple
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- past_key_value=self_attn_past_key_value,
- attention_mask=attention_mask,
- layer_head_mask=layer_head_mask,
- output_attentions=output_attentions,
- )
- hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- hidden_states = self.self_attn_layer_norm(hidden_states)
-
- # Cross-Attention Block
- cross_attn_present_key_value = None
- cross_attn_weights = None
- if encoder_hidden_states is not None:
- residual = hidden_states
-
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
- hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
- hidden_states=hidden_states,
- key_value_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- layer_head_mask=encoder_layer_head_mask,
- past_key_value=cross_attn_past_key_value,
- output_attentions=output_attentions,
- )
- hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- hidden_states = self.encoder_attn_layer_norm(hidden_states)
-
- # add cross-attn to positions 3,4 of present_key_value tuple
- present_key_value = present_key_value + cross_attn_present_key_value
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
- hidden_states = self.fc2(hidden_states)
- hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- hidden_states = self.final_layer_norm(hidden_states)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights, cross_attn_weights)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
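- # Note: as in BART, LayerNorm is applied after each residual connection (post-LN). The returned
- # tuple is (hidden_states,), optionally followed by (self_attn_weights, cross_attn_weights) when
- # output_attentions is set, and by present_key_value when use_cache is set.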
-
-
- class CPTClassificationHead(Module):
- """Head for sentence-level classification tasks."""
-
- def __init__(
- self,
- input_dim: int,
- inner_dim: int,
- num_classes: int,
- pooler_dropout: float,
- ):
- super().__init__()
- self.dense = nn.Linear(input_dim, inner_dim)
- self.dropout = nn.Dropout(p=pooler_dropout)
- self.out_proj = nn.Linear(inner_dim, num_classes)
-
- def forward(self, hidden_states: "torch.Tensor"):
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.dense(hidden_states)
- hidden_states = torch.tanh(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.out_proj(hidden_states)
- return hidden_states
-
-
- class CPTPretrainedModel(PreTrainedModel):
- config_class = CPTConfig
- base_model_prefix = "model"
-
- def _init_weights(self, module):
- std = self.config.init_std
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- @property
- def dummy_inputs(self):
- pad_token = self.config.pad_token_id
- input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
- dummy_inputs = {
- "attention_mask": input_ids.ne(pad_token),
- "input_ids": input_ids,
- }
- return dummy_inputs
-
- CPT_START_DOCSTRING = r"""
- This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
- methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
- pruning heads etc.)
- This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
- subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
- general usage and behavior.
- Parameters:
- config (:class:`~transformers.CPTConfig`):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
- """
-
- CPT_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
- Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
- :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
- details.
- `What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
- Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- `What are attention masks? <../glossary.html#attention-mask>`__
- decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
- Indices of decoder input sequence tokens in the vocabulary.
- Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
- :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
- details.
- `What are input IDs? <../glossary.html#input-ids>`__
- CPT uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
- :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
- :obj:`past_key_values`).
- For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
- :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
- the right for denoising pre-training following the paper.
- decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
- Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
- also be used by default.
- If you want to change padding behavior, you should read :func:`modeling_cpt._prepare_decoder_inputs` and
- modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
- information on the default strategy.
- head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
- Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
- Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
- Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
- :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
- `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
- cross-attention of the decoder.
- past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
- Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
- inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
- Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
- This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
- vectors than the model's internal embedding lookup matrix.
- decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
- Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
- representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
- have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
- :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
- If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
- takes the value of :obj:`inputs_embeds`.
- use_cache (:obj:`bool`, `optional`):
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
- decoding (see :obj:`past_key_values`).
- output_attentions (:obj:`bool`, `optional`):
- Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
- tensors for more detail.
- output_hidden_states (:obj:`bool`, `optional`):
- Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
- more detail.
- return_dict (:obj:`bool`, `optional`):
- Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
- """
-
- class CPTDecoder(CPTPretrainedModel):
- """
- Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`CPTDecoderLayer`
- Args:
- config: CPTConfig
- embed_tokens (torch.nn.Embedding): output embedding
- """
-
- def __init__(self, config: CPTConfig, embed_tokens: Optional["nn.Embedding"] = None):
- super().__init__(config)
- self.dropout = config.dropout
- self.layerdrop = config.decoder_layerdrop
- self.padding_idx = config.pad_token_id
- self.max_target_positions = config.max_position_embeddings
- self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-
- if embed_tokens is not None:
- self.embed_tokens = embed_tokens
- else:
- self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
- self.embed_positions = CPTLearnedPositionalEmbedding(
- config.max_position_embeddings,
- config.d_model,
- )
- self.layers = nn.ModuleList([CPTDecoderLayer(config) for _ in range(config.decoder_layers)])
- self.layernorm_embedding = LayerNorm(config.d_model)
-
- self.init_weights()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
- combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
- )
-
- return combined_attention_mask
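- # Illustrative example: for a 3-token target whose attention_mask is [1, 1, 0], the causal mask
- # contributes -inf above the diagonal and the expanded padding mask contributes the dtype minimum
- # in the last key column; the two additive masks are summed, so a key position stays visible only
- # where both terms are 0.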
-
- def forward(
- self,
- input_ids=None,
- attention_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- head_mask=None,
- encoder_head_mask=None,
- past_key_values=None,
- inputs_embeds=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- r"""
- Args:
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
- provide it.
- Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
- :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
- for details.
- `What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
- Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- `What are attention masks? <../glossary.html#attention-mask>`__
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
- of the decoder.
- encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
- Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
- selected in ``[0, 1]``:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- `What are attention masks? <../glossary.html#attention-mask>`__
- head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
- Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
- Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
- on hidden heads. Mask values selected in ``[0, 1]``:
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
- Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
- decoding.
- If :obj:`past_key_values` are used, the user can optionally input only the last
- :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
- shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size,
- sequence_length)`.
- inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
- Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
- representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
- into associated vectors than the model's internal embedding lookup matrix.
- output_attentions (:obj:`bool`, `optional`):
- Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
- returned tensors for more detail.
- output_hidden_states (:obj:`bool`, `optional`):
- Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
- for more detail.
- return_dict (:obj:`bool`, `optional`):
- Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
- """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- input_shape = input_ids.size()
- input_ids = input_ids.view(-1, input_shape[-1])
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
- # past_key_values_length
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
- )
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-
- # embed positions
- positions = self.embed_positions(input_shape, past_key_values_length)
-
- hidden_states = inputs_embeds + positions
- hidden_states = self.layernorm_embedding(hidden_states)
-
- hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
- next_decoder_cache = () if use_cache else None
-
- # check if head_mask has a correct number of layers specified if desired
- if head_mask is not None:
- assert head_mask.size()[0] == (
- len(self.layers)
- ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
- for idx, decoder_layer in enumerate(self.layers):
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- dropout_probability = random.uniform(0, 1)
- if self.training and (dropout_probability < self.layerdrop):
- continue
-
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
- if getattr(self.config, "gradient_checkpointing", False) and self.training:
-
- if use_cache:
- logger.warning(
- "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
- "`use_cache=False`..."
- )
- use_cache = False
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, use_cache)
-
- return custom_forward
-
- # layer_outputs = mpu.checkpoint(
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
- hidden_states,
- attention_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- head_mask[idx] if head_mask is not None else None,
- encoder_head_mask[idx] if encoder_head_mask is not None else None,
- None,
- )
- else:
-
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
- encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- if encoder_hidden_states is not None:
- all_cross_attentions += (layer_outputs[2],)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = next_decoder_cache if use_cache else None
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
- if v is not None
- )
- return BaseModelOutputWithPastAndCrossAttentions(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- cross_attentions=all_cross_attentions,
- )
-
-
- @add_start_docstrings(
- "The bare CPT Model outputting raw hidden-states without any specific head on top.",
- CPT_START_DOCSTRING,
- )
- class CPTModel(CPTPretrainedModel):
- def __init__(self, config: CPTConfig):
- super().__init__(config)
- encoder_config = BertConfig(
- vocab_size=config.vocab_size,
- hidden_size=config.d_model,
- num_hidden_layers=config.encoder_layers,
- num_attention_heads=config.encoder_attention_heads,
- intermediate_size=config.encoder_ffn_dim,
- hidden_dropout_prob=config.activation_dropout,
- attention_probs_dropout_prob=config.attention_dropout,
- )
- config.vocab_size = encoder_config.vocab_size
- self.encoder = BertModel(encoder_config, add_pooling_layer=False)
- self.shared = self.encoder.get_input_embeddings()
- self.decoder = CPTDecoder(config, self.shared)
- self.num_decoder_layers = config.decoder_layers
- self.init_weights()
-
- def get_input_embeddings(self):
- return self.shared
-
- def set_input_embeddings(self, value):
- self.shared = value
- self.encoder.set_input_embeddings(self.shared)
- self.decoder.embed_tokens = self.shared
-
- def get_encoder(self):
- class _Encoder(torch.nn.Module):
- def __init__(self, encoder):
- super().__init__()
- self.encoder = encoder
-
- def forward(self, *args, **kwargs):
- kwargs['output_hidden_states'] = True
- return self.encoder(*args, **kwargs)
- return _Encoder(self.encoder)
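- # Note: the encoder is wrapped so that callers (e.g. `generate`) always receive the hidden states
- # of every encoder layer, because the decoder cross-attends to an intermediate encoder layer
- # rather than to the final one (see `forward` below).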
-
- def get_decoder(self):
- return self.decoder
-
- @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- tokenizer_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Seq2SeqModelOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids=None,
- attention_mask=None,
- decoder_input_ids=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- encoder_outputs=None,
- past_key_values=None,
- inputs_embeds=None,
- decoder_inputs_embeds=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
-
- # Unlike other models, CPT automatically creates decoder_input_ids from
- # input_ids if no decoder_input_ids are provided
- if decoder_input_ids is None and decoder_inputs_embeds is None:
- decoder_input_ids = shift_tokens_right(
- input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
- )
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if getattr(self.config, "gradient_checkpointing", False) and self.training:
- # mpu.reset_checkpointed_activations_memory_buffer()
- use_cache = False
-
- if encoder_outputs is None:
- encoder_outputs = self.encoder(
- input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=torch.ones_like(input_ids),
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=True,
- return_dict=return_dict,
- )
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
- elif return_dict and isinstance(encoder_outputs, (tuple, list)):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
-
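- # The decoder does not attend to the final encoder layer: it uses the hidden state located
- # `num_decoder_layers` layers below the top of the encoder stack (index -num_decoder_layers - 1
- # in the hidden-states tuple, which also counts the embedding output). This reflects CPT's design,
- # where the top encoder layers and the shallow decoder act as parallel understanding/generation
- # branches on top of a shared lower stack.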
- if isinstance(encoder_outputs, torch.Tensor):
- encoder_hidden_states = encoder_outputs
- else:
- encoder_hidden_states = encoder_outputs[1][-self.num_decoder_layers - 1]
-
- # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
- decoder_outputs = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=decoder_attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=attention_mask,
- head_mask=decoder_head_mask,
- encoder_head_mask=head_mask,
- past_key_values=past_key_values,
- inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- if not return_dict:
- return decoder_outputs + encoder_outputs
-
- return Seq2SeqModelOutput(
- last_hidden_state=decoder_outputs.last_hidden_state,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state if isinstance(encoder_outputs, dict) else None,
- encoder_hidden_states=encoder_outputs.hidden_states if isinstance(encoder_outputs, dict) else None,
- encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None,
- )
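- # Minimal usage sketch (commented out, illustrative only; assumes the "fnlp/cpt-large" weights and
- # a compatible tokenizer producing `input_ids`/`attention_mask` LongTensors of shape (bsz, seq_len)):
- #
- #     model = CPTModel.from_pretrained("fnlp/cpt-large")
- #     out = model(input_ids=input_ids, attention_mask=attention_mask)
- #     out.last_hidden_state  # decoder output of shape (bsz, seq_len, config.d_model)
- #
- # When decoder_input_ids is omitted, it is derived from input_ids via shift_tokens_right above.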
-
-
- @add_start_docstrings(
- "The CPT Model with a language modeling head. Can be used for summarization.", CPT_START_DOCSTRING
- )
- class CPTForConditionalGeneration(CPTPretrainedModel):
- base_model_prefix = "model"
- _keys_to_ignore_on_load_missing = [
- r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
- ]
-
- def __init__(self, config):
- super().__init__(config)
- self.model = CPTModel(config)
- self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
- self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-
- self.init_weights()
-
- def get_encoder(self):
- return self.model.get_encoder()
-
- def get_decoder(self):
- return self.model.get_decoder()
-
- def resize_token_embeddings(self, new_num_tokens: int) -> "nn.Embedding":
- new_embeddings = super().resize_token_embeddings(new_num_tokens)
- self._resize_final_logits_bias(new_num_tokens)
- return new_embeddings
-
- def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
- old_num_tokens = self.final_logits_bias.shape[-1]
- if new_num_tokens <= old_num_tokens:
- new_bias = self.final_logits_bias[:, :new_num_tokens]
- else:
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
- new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
- self.register_buffer("final_logits_bias", new_bias)
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids=None,
- attention_mask=None,
- decoder_input_ids=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- encoder_outputs=None,
- past_key_values=None,
- inputs_embeds=None,
- decoder_inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- r"""
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
- Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
- config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
- (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
- Returns:
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if labels is not None:
- if decoder_input_ids is None:
- decoder_input_ids = shift_tokens_right(
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
- )
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- decoder_input_ids=decoder_input_ids,
- encoder_outputs=encoder_outputs,
- decoder_attention_mask=decoder_attention_mask,
- head_mask=head_mask,
- decoder_head_mask=decoder_head_mask,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
-
- masked_lm_loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
-
- if not return_dict:
- output = (lm_logits,) + outputs[1:]
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
-
- return Seq2SeqLMOutput(
- loss=masked_lm_loss,
- logits=lm_logits,
- past_key_values=outputs.past_key_values,
- decoder_hidden_states=outputs.decoder_hidden_states,
- decoder_attentions=outputs.decoder_attentions,
- cross_attentions=outputs.cross_attentions,
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
- encoder_hidden_states=outputs.encoder_hidden_states,
- encoder_attentions=outputs.encoder_attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- decoder_input_ids,
- past=None,
- attention_mask=None,
- head_mask=None,
- use_cache=None,
- encoder_outputs=None,
- **kwargs
- ):
- # cut decoder_input_ids if past is used
- if past is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
-
- return {
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
- "encoder_outputs": encoder_outputs,
- "past_key_values": past,
- "decoder_input_ids": decoder_input_ids,
- "attention_mask": attention_mask,
- "head_mask": head_mask,
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
- }
-
- @staticmethod
- def _expand_inputs_for_generation(
- input_ids: "torch.LongTensor",
- expand_size: int = 1,
- is_encoder_decoder: bool = False,
- attention_mask: "torch.LongTensor" = None,
- encoder_outputs = None,
- **model_kwargs,
- ):
- expanded_return_idx = (
- torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
- )
- input_ids = input_ids.index_select(0, expanded_return_idx)
-
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
-
- if attention_mask is not None:
- model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
-
- if is_encoder_decoder:
- assert encoder_outputs is not None
- device = encoder_outputs.last_hidden_state.device
- encoder_outputs["hidden_states"] = tuple(h.index_select(0, expanded_return_idx.to(device)) \
- for h in encoder_outputs["hidden_states"])
- model_kwargs["encoder_outputs"] = encoder_outputs
- return input_ids, model_kwargs
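- # This override of the generation helper expands every entry of encoder_outputs["hidden_states"]
- # along the batch dimension (for beam search / multiple return sequences), not just
- # last_hidden_state, because the decoder consumes an intermediate encoder hidden state.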
-
- def prepare_decoder_input_ids_from_labels(self, labels: "torch.Tensor"):
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
- @staticmethod
- def _reorder_cache(past, beam_idx):
- reordered_past = ()
- for layer_past in past:
- # cached cross_attention states don't have to be reordered -> they are always the same
- reordered_past += (
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
- )
- return reordered_past
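- # Illustrative generation sketch (commented out; assumes the `generate` method inherited from the
- # pretrained-model base class and a tokenizer matching the checkpoint):
- #
- #     model = CPTForConditionalGeneration.from_pretrained("fnlp/cpt-large")
- #     summary_ids = model.generate(input_ids, attention_mask=attention_mask,
- #                                  num_beams=4, max_length=64)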
-
-
- @add_start_docstrings(
- """
- CPT model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
- tasks.
- """,
- CPT_START_DOCSTRING,
- )
- class CPTForSequenceClassification(CPTPretrainedModel):
- def __init__(self, config: CPTConfig, cls_mode=1, **kwargs):
- super().__init__(config, **kwargs)
- self.model = CPTModel(config)
- cls_mode = getattr(config, 'cls_mode', cls_mode)
- if cls_mode == 1:
- logger.info('Encoder for classification.')
- cls_dim = config.d_model
- elif cls_mode == 2:
- logger.info('Decoder for classification.')
- cls_dim = config.d_model
- elif cls_mode == 3:
- logger.info('Both encoder & decoder for classification.')
- cls_dim = config.d_model * 2
- else:
- raise NotImplementedError
-
- self.cls_head = CPTClassificationHead(
- cls_dim,
- cls_dim,
- config.num_labels,
- config.classifier_dropout,
- )
- self.model._init_weights(self.cls_head.dense)
- self.model._init_weights(self.cls_head.out_proj)
- self.cls_mode = cls_mode
- config.cls_mode = cls_mode
-
- @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- tokenizer_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Seq2SeqSequenceClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids=None,
- attention_mask=None,
- decoder_input_ids=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- encoder_outputs=None,
- inputs_embeds=None,
- decoder_inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- r"""
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
- config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if labels is not None:
- use_cache = False
-
- if input_ids is None and inputs_embeds is not None:
- raise NotImplementedError(
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
- )
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- decoder_input_ids=decoder_input_ids,
- decoder_attention_mask=decoder_attention_mask,
- head_mask=head_mask,
- decoder_head_mask=decoder_head_mask,
- encoder_outputs=encoder_outputs,
- inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=True,
- )
-
- hidden_states = outputs.last_hidden_state
- enc_hidden_states = outputs.encoder_last_hidden_state
- enc_rep = enc_hidden_states[:, 0]
-
- eos_mask = input_ids.eq(self.config.eos_token_id)
-
- if len(torch.unique(eos_mask.sum(1))) > 1:
- raise ValueError("All examples must have the same number of <eos> tokens.")
- dec_rep = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
- :, -1, :
- ]
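- # enc_rep pools the encoder output at the first token, while dec_rep takes the decoder hidden
- # state at the last <eos> position of each sequence; cls_mode selects the encoder representation,
- # the decoder representation, or their concatenation as the classifier input.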
-
- if self.cls_mode == 1:
- logits = self.cls_head(enc_rep)
- elif self.cls_mode == 2:
- logits = self.cls_head(dec_rep)
- elif self.cls_mode == 3:
- rep = torch.cat([enc_rep, dec_rep], dim=-1)
- logits = self.cls_head(rep)
- else:
- raise NotImplementedError
-
- loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return Seq2SeqSequenceClassifierOutput(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- decoder_hidden_states=outputs.decoder_hidden_states,
- decoder_attentions=outputs.decoder_attentions,
- cross_attentions=outputs.cross_attentions,
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
- encoder_hidden_states=outputs.encoder_hidden_states,
- encoder_attentions=outputs.encoder_attentions,
- )
-
-
- @add_start_docstrings(
- """
- CPT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
- layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
- """,
- CPT_START_DOCSTRING,
- )
- class CPTForQuestionAnswering(CPTPretrainedModel):
- def __init__(self, config: CPTConfig, cls_mode=1, **kwargs):
- super().__init__(config, **kwargs)
- config.num_labels = 2
- self.num_labels = config.num_labels
-
- self.model = CPTModel(config)
-
- cls_mode = getattr(config, 'cls_mode', cls_mode)
- if cls_mode == 1:
- logger.info('Encoder for classification.')
- cls_dim = config.d_model
- elif cls_mode == 2:
- logger.info('Decoder for classification.')
- cls_dim = config.d_model
- elif cls_mode == 3:
- logger.info('Both encoder & decoder for classification.')
- cls_dim = config.d_model * 2
- else:
- raise NotImplementedError
-
- self.qa_outputs = nn.Linear(cls_dim, config.num_labels)
- self.model._init_weights(self.qa_outputs)
-
- self.cls_mode = cls_mode
- config.cls_mode = cls_mode
-
- @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- tokenizer_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Seq2SeqSequenceClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids=None,
- attention_mask=None,
- decoder_input_ids=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- encoder_outputs=None,
- start_positions=None,
- end_positions=None,
- inputs_embeds=None,
- decoder_inputs_embeds=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- r"""
- start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if input_ids is None and inputs_embeds is not None:
- raise NotImplementedError(
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
- )
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- decoder_input_ids=decoder_input_ids,
- decoder_attention_mask=decoder_attention_mask,
- head_mask=head_mask,
- decoder_head_mask=decoder_head_mask,
- encoder_outputs=encoder_outputs,
- inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=True,
- )
-
- hidden_states = outputs.last_hidden_state
- enc_hidden_states = outputs.encoder_last_hidden_state
-
- if self.cls_mode == 1:
- logits = self.qa_outputs(enc_hidden_states)
- elif self.cls_mode == 2:
- logits = self.qa_outputs(hidden_states)
- elif self.cls_mode == 3:
- rep = torch.cat([enc_hidden_states, hidden_states], dim=-1)
- logits = self.qa_outputs(rep)
- else:
- raise NotImplementedError
-
- start_logits, end_logits = logits.split(1, dim=-1)
- start_logits = start_logits.squeeze(-1)
- end_logits = end_logits.squeeze(-1)
-
- total_loss = None
- if start_positions is not None and end_positions is not None:
- # If we are on multi-GPU, split adds a dimension
- if len(start_positions.size()) > 1:
- start_positions = start_positions.squeeze(-1)
- if len(end_positions.size()) > 1:
- end_positions = end_positions.squeeze(-1)
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
- ignored_index = start_logits.size(1)
- start_positions.clamp_(0, ignored_index)
- end_positions.clamp_(0, ignored_index)
-
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
- start_loss = loss_fct(start_logits, start_positions)
- end_loss = loss_fct(end_logits, end_positions)
- total_loss = (start_loss + end_loss) / 2
-
- if not return_dict:
- output = (
- start_logits,
- end_logits,
- ) + outputs[1:]
- return ((total_loss,) + output) if total_loss is not None else output
-
- return Seq2SeqQuestionAnsweringModelOutput(
- loss=total_loss,
- start_logits=start_logits,
- end_logits=end_logits,
- past_key_values=outputs.past_key_values,
- decoder_hidden_states=outputs.decoder_hidden_states,
- decoder_attentions=outputs.decoder_attentions,
- cross_attentions=outputs.cross_attentions,
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
- encoder_hidden_states=outputs.encoder_hidden_states,
- encoder_attentions=outputs.encoder_attentions,
- )
-
-
- class CPTForMaskedLM(CPTPretrainedModel):
- _keys_to_ignore_on_load_missing = [
- r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
- ]
- def __init__(self, config, **kwargs):
- super().__init__(config)
- self.model = CPTModel(config)
- self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
- self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-
- self.init_weights()
-
- def get_encoder(self):
- return self.model.get_encoder()
-
- def get_decoder(self):
- return self.model.get_decoder()
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def forward(
- self,
- input_ids=None,
- attention_mask=None,
- decoder_input_ids=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- encoder_outputs=None,
- inputs_embeds=None,
- decoder_inputs_embeds=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if input_ids is None and inputs_embeds is not None:
- raise NotImplementedError(
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
- )
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- decoder_input_ids=decoder_input_ids,
- decoder_attention_mask=decoder_attention_mask,
- head_mask=head_mask,
- decoder_head_mask=decoder_head_mask,
- encoder_outputs=encoder_outputs,
- inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=True,
- )
-
- hidden_states = outputs.last_hidden_state
- enc_hidden_states = outputs.encoder_last_hidden_state
-
- dec_logits = self.lm_head(hidden_states) + self.final_logits_bias
- enc_logits = self.lm_head(enc_hidden_states) + self.final_logits_bias
-
- if not return_dict:
- logits = (enc_logits, dec_logits)
- output = (logits,) + outputs[1:]
- return output
-
- return Seq2SeqLMOutput(
- loss=None,
- logits=(enc_logits, dec_logits),
- past_key_values=outputs.past_key_values,
- decoder_hidden_states=outputs.decoder_hidden_states,
- decoder_attentions=outputs.decoder_attentions,
- cross_attentions=outputs.cross_attentions,
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
- encoder_hidden_states=outputs.encoder_hidden_states,
- encoder_attentions=outputs.encoder_attentions,
- )