- """
- bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
-
- """
import torch
from torch import nn

from .base_model import BaseModel
from fastNLP.modules.encoder import BertModel


class BertForSequenceClassification(BaseModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.
        `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin`.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary. Items in the batch should begin with the special
            "[CLS]" token (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_labels - 1].
    Outputs:
        if `labels` is not `None`:
            Outputs a dict with the classification logits of shape [batch_size, num_labels] under "pred"
            and the CrossEntropy classification loss under "loss".
        if `labels` is `None`:
            Outputs a dict with the classification logits of shape [batch_size, num_labels] under "pred".
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    bert_dir = 'your-bert-file-dir'
    model = BertForSequenceClassification(config, num_labels, bert_dir)
    logits = model(input_ids, token_type_ids, input_mask)["pred"]
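    # not part of the original example: turn the logits into hard class predictions
    pred = torch.argmax(logits, dim=-1)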
    ```
    """
    def __init__(self, config, num_labels, bert_dir):
        super(BertForSequenceClassification, self).__init__()
        self.num_labels = num_labels
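        # load the pre-trained BERT encoder weights from `bert_dir` (a directory containing `pytorch_model.bin`)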
        self.bert = BertModel.from_pretrained(bert_dir)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
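        # encode the sequence; the pooled output is the transformed hidden state of the first ([CLS]) token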
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

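        # when gold labels are given, also return the cross-entropy loss under 'loss' so training code can use it directly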
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return {"pred": logits, "loss": loss}
        else:
            return {"pred": logits}

    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
        return {"pred": torch.argmax(logits, dim=-1)}


class BertForMultipleChoice(BaseModel):
    """BERT model for multiple choice tasks.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_choices`: the number of choices for the classifier. Default = 2.
        `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin`.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
            and type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1].
    Outputs:
        if `labels` is not `None`:
            Outputs a dict with the classification logits of shape [batch_size, num_choices] under "pred"
            and the CrossEntropy classification loss under "loss".
        if `labels` is `None`:
            Outputs a dict with the classification logits of shape [batch_size, num_choices] under "pred".
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_choices = 2
    bert_dir = 'your-bert-file-dir'
    model = BertForMultipleChoice(config, num_choices, bert_dir)
    logits = model(input_ids, token_type_ids, input_mask)["pred"]
    ```
    """
    def __init__(self, config, num_choices, bert_dir):
        super(BertForMultipleChoice, self).__init__()
        self.num_choices = num_choices
        self.bert = BertModel.from_pretrained(bert_dir)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
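        # a single score per (example, choice) pair; the scores are later reshaped and compared across choices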
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
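        # flatten [batch_size, num_choices, seq_len] inputs into [batch_size * num_choices, seq_len]
        # so that every choice is encoded independently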
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
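        # fold the flat per-choice scores back into [batch_size, num_choices] so loss/argmax is taken over choices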
        reshaped_logits = logits.view(-1, self.num_choices)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            return {"pred": reshaped_logits, "loss": loss}
        else:
            return {"pred": reshaped_logits}

    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
        return {"pred": torch.argmax(logits, dim=-1)}


class BertForTokenClassification(BaseModel):
    """BERT model for token-level classification.
    This module is composed of the BERT model with a linear layer on top of
    the full hidden state of the last layer.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.
        `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin`.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, ..., num_labels - 1].
    Outputs:
        if `labels` is not `None`:
            Outputs a dict with the classification logits of shape [batch_size, sequence_length, num_labels]
            under "pred" and the CrossEntropy classification loss under "loss".
        if `labels` is `None`:
            Outputs a dict with the classification logits of shape [batch_size, sequence_length, num_labels] under "pred".
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    bert_dir = 'your-bert-file-dir'
    model = BertForTokenClassification(config, num_labels, bert_dir)
    logits = model(input_ids, token_type_ids, input_mask)["pred"]
    ```
    """
    def __init__(self, config, num_labels, bert_dir):
        super(BertForTokenClassification, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained(bert_dir)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
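        # classify every token position from the last-layer hidden states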
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
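                # keep only the non-padding positions so padding tokens do not contribute to the loss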
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return {"pred": logits, "loss": loss}
        else:
            return {"pred": logits}

    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
        return {"pred": torch.argmax(logits, dim=-1)}


class BertForQuestionAnswering(BaseModel):
    """BERT model for Question Answering (span extraction).
    This module is composed of the BERT model with a linear layer on top of
    the sequence output that computes start_logits and end_logits.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin`.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `start_positions`: position of the first token of the labeled span: torch.LongTensor of shape [batch_size].
            Positions are clamped to the length of the sequence, and positions outside of the sequence are not taken
            into account for computing the loss.
        `end_positions`: position of the last token of the labeled span: torch.LongTensor of shape [batch_size].
            Positions are clamped to the length of the sequence, and positions outside of the sequence are not taken
            into account for computing the loss.
    Outputs:
        if `start_positions` and `end_positions` are not `None`:
            Outputs a dict with the total loss under "loss", which is the average of the CrossEntropy losses
            for the start and end token positions.
        if `start_positions` or `end_positions` is `None`:
            Outputs a dict with the start and end logits under "pred1" and "pred2" respectively, each of
            shape [batch_size, sequence_length].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    bert_dir = 'your-bert-file-dir'
    model = BertForQuestionAnswering(config, bert_dir)
    output = model(input_ids, token_type_ids, input_mask)
    start_logits, end_logits = output["pred1"], output["pred2"]
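    # not part of the original example: pick the most likely start/end position for each sequence
    start_idx = torch.argmax(start_logits, dim=-1)
    end_idx = torch.argmax(end_logits, dim=-1)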
    ```
    """
    def __init__(self, config, bert_dir):
        super(BertForQuestionAnswering, self).__init__()
        self.bert = BertModel.from_pretrained(bert_dir)
        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
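        # score every token position for being the start and for being the end of the answer span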
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        logits = self.qa_outputs(sequence_output)
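        # split the two output channels of qa_outputs into per-token start and end logits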
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the start/end position tensors may carry an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; clamp them to `ignored_index`
            # so that CrossEntropyLoss ignores these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            return {"loss": total_loss}
        else:
            return {"pred1": start_logits, "pred2": end_logits}

    def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
        logits = self.forward(input_ids, token_type_ids, attention_mask)
        start_logits = logits["pred1"]
        end_logits = logits["pred2"]
        return {"pred1": torch.argmax(start_logits, dim=-1), "pred2": torch.argmax(end_logits, dim=-1)}