@@ -2,13 +2,14 @@
 bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.

 """
+import os
 import torch
 from torch import nn

 from .base_model import BaseModel
 from ..core.const import Const
 from ..modules.encoder import BertModel
-from ..modules.encoder.bert import BertConfig
+from ..modules.encoder.bert import BertConfig, CONFIG_FILE


 class BertForSequenceClassification(BaseModel):
@@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel):
         self.num_labels = num_labels
         if bert_dir is not None:
             self.bert = BertModel.from_pretrained(bert_dir)
+            config = BertConfig(os.path.join(bert_dir, CONFIG_FILE))
         else:
             if config is None:
                 config = BertConfig(30522)
@@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel):
         model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+    def forward(self, words, seq_len=None, target=None):
+        _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)

-        if labels is not None:
+        if target is not None:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits, target)
             return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
             return {Const.OUTPUT: logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)
+    def predict(self, words, seq_len=None):
+        logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}

@@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel):
         model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
+    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
+        input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
@@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)

-        if labels is not None:
+        if target is not None:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(reshaped_logits, labels)
+            loss = loss_fct(reshaped_logits, target)
             return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
         else:
             return {Const.OUTPUT: reshaped_logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+    def predict(self, words, seq_len1=None, seq_len2=None,):
+        logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}

@@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel):
         model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
+        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)

-        if labels is not None:
+        if target is not None:
             loss_fct = nn.CrossEntropyLoss()
             # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
+            if seq_len2 is not None:
+                active_loss = seq_len2.view(-1) == 1
                 active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_labels = target.view(-1)[active_loss]
                 loss = loss_fct(active_logits, active_labels)
             else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1))
             return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
             return {Const.OUTPUT: logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+    def predict(self, words, seq_len1=None, seq_len2=None):
+        logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}

@@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel):
         model = cls(config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+    def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None):
+        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        if start_positions is not None and end_positions is not None:
+        if target1 is not None and target2 is not None:
             # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
+            if len(target1.size()) > 1:
+                target1 = target1.squeeze(-1)
+            if len(target2.size()) > 1:
+                target2 = target2.squeeze(-1)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
             ignored_index = start_logits.size(1)
-            start_positions.clamp_(0, ignored_index)
-            end_positions.clamp_(0, ignored_index)
+            target1.clamp_(0, ignored_index)
+            target2.clamp_(0, ignored_index)

             loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
+            start_loss = loss_fct(start_logits, target1)
+            end_loss = loss_fct(end_logits, target2)
             total_loss = (start_loss + end_loss) / 2
             return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
         else:
             return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)
+    def predict(self, words, seq_len1=None, seq_len2=None):
+        logits = self.forward(words, seq_len1, seq_len2)
         start_logits = logits[Const.OUTPUTS(0)]
         end_logits = logits[Const.OUTPUTS(1)]
         return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
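
For context, a minimal usage sketch of the renamed interface introduced above (not part of the diff). It assumes the file sits at fastNLP/models/bert.py so the import paths below resolve, that the constructor accepts keyword arguments num_labels/config/bert_dir as the from_pretrained call suggests, and that a BertModel is built from the given config when no bert_dir is supplied; shapes and the vocabulary size are illustrative only.

import torch
from fastNLP.core.const import Const                             # assumed import path
from fastNLP.models.bert import BertForSequenceClassification    # assumed import path
from fastNLP.modules.encoder.bert import BertConfig              # assumed import path

# Randomly initialised toy model; BertConfig(30522) mirrors the default used in the diff.
model = BertForSequenceClassification(num_labels=2, config=BertConfig(30522))

words = torch.randint(0, 30522, (4, 16))        # batch of 4 sequences of 16 token ids
seq_len = torch.ones(4, 16, dtype=torch.long)   # forwarded to BertModel as the attention mask
target = torch.tensor([0, 1, 1, 0])             # one class label per sequence

train_out = model(words, seq_len=seq_len, target=target)   # {Const.OUTPUT: logits, Const.LOSS: loss}
pred_out = model.predict(words, seq_len=seq_len)            # {Const.OUTPUT: argmax over the logits}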