
rename param names in model/bert.py to adjust fastNLP.Const

tags/v0.4.10
xuyige 5 years ago
commit 02c8fc0de7
1 changed file with 37 additions and 34 deletions

fastNLP/models/bert.py  (+37 / -34)

@@ -2,13 +2,14 @@
 bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.

 """
+import os
 import torch
 from torch import nn

 from .base_model import BaseModel
 from ..core.const import Const
 from ..modules.encoder import BertModel
-from ..modules.encoder.bert import BertConfig
+from ..modules.encoder.bert import BertConfig, CONFIG_FILE


 class BertForSequenceClassification(BaseModel):
@@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel):
         self.num_labels = num_labels
         if bert_dir is not None:
             self.bert = BertModel.from_pretrained(bert_dir)
+            config = BertConfig(os.path.join(bert_dir, CONFIG_FILE))
         else:
             if config is None:
                 config = BertConfig(30522)
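
Not part of the diff: with the added line, building the model from a pretrained directory now also loads the JSON config from CONFIG_FILE inside that directory instead of leaving `config` unset. A minimal construction sketch, assuming only the `num_labels` / `config` / `bert_dir` keywords that appear elsewhere in this file:

from fastNLP.models.bert import BertForSequenceClassification

# Hypothetical local path: weights come from BertModel.from_pretrained and,
# after this change, the config from <bert_dir>/CONFIG_FILE.
# model = BertForSequenceClassification(num_labels=2, bert_dir="/path/to/bert")

# Without a directory the class falls back to BertConfig(30522), i.e. a
# randomly initialised BERT with the standard 30522-token vocabulary.
model = BertForSequenceClassification(num_labels=2)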
@@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel):
         model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+    def forward(self, words, seq_len=None, target=None):
+        _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)

-        if labels is not None:
+        if target is not None:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits, target)
             return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
             return {Const.OUTPUT: logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)
+    def predict(self, words, seq_len=None):
+        logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
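
Not part of the diff: a call sketch for the renamed signature. The first positional input is now `words` (token ids), `seq_len` is forwarded as the attention mask, and the returned dict is keyed by Const.OUTPUT / Const.LOSS. Shapes and the random-config construction are assumptions for illustration:

import torch
from fastNLP.core.const import Const
from fastNLP.models.bert import BertForSequenceClassification

model = BertForSequenceClassification(num_labels=2)       # falls back to BertConfig(30522)
words = torch.randint(0, 30522, (4, 16))                   # [batch, seq_len] token ids
mask = torch.ones(4, 16, dtype=torch.long)                 # passed through as attention_mask

out = model(words, seq_len=mask, target=torch.tensor([0, 1, 1, 0]))
print(out[Const.OUTPUT].shape, out[Const.LOSS].item())     # logits [4, 2] and a scalar loss

pred = model.predict(words, seq_len=mask)[Const.OUTPUT]    # argmax over the class dimension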


@@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel):
         model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
+    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
+        input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
@@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)

-        if labels is not None:
+        if target is not None:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(reshaped_logits, labels)
+            loss = loss_fct(reshaped_logits, target)
             return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
         else:
             return {Const.OUTPUT: reshaped_logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+    def predict(self, words, seq_len1=None, seq_len2=None,):
+        logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
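
Not part of the diff: for the multiple-choice head only the names change; `seq_len1` / `seq_len2` stand in for token_type_ids / attention_mask, as the first line of the new forward() above shows. A shape sketch with sizes chosen purely for illustration:

import torch
from fastNLP.core.const import Const
from fastNLP.models.bert import BertForMultipleChoice
from fastNLP.modules.encoder.bert import BertConfig

model = BertForMultipleChoice(num_choices=4, config=BertConfig(30522))
words = torch.randint(0, 30522, (2, 4, 16))         # [batch, num_choices, seq_len]
segments = torch.zeros(2, 4, 16, dtype=torch.long)  # plays the old token_type_ids role
mask = torch.ones(2, 4, 16, dtype=torch.long)       # plays the old attention_mask role

out = model(words, seq_len1=segments, seq_len2=mask, target=torch.tensor([1, 3]))
print(out[Const.OUTPUT].shape)                       # [batch, num_choices] choice logits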


@@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel):
         model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
+        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)

-        if labels is not None:
+        if target is not None:
             loss_fct = nn.CrossEntropyLoss()
             # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
+            if seq_len2 is not None:
+                active_loss = seq_len2.view(-1) == 1
                 active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_labels = target.view(-1)[active_loss]
                 loss = loss_fct(active_logits, active_labels)
             else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1))
             return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
             return {Const.OUTPUT: logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+    def predict(self, words, seq_len1=None, seq_len2=None):
+        logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
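
Not part of the diff: the masking logic above is unchanged apart from the names; it restricts the token-level loss to non-padded positions. A standalone PyTorch illustration of the same computation:

import torch
from torch import nn

num_labels = 5
logits = torch.randn(2, 7, num_labels)             # [batch, seq_len, num_labels]
target = torch.randint(0, num_labels, (2, 7))      # gold tag ids per token
mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0]])       # 1 = real token, 0 = padding

active = mask.view(-1) == 1                         # keep only the real positions
loss = nn.CrossEntropyLoss()(logits.view(-1, num_labels)[active],
                             target.view(-1)[active])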


@@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel):
         model = cls(config=config, bert_dir=pretrained_model_dir)
         return model

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+    def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None):
+        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        if start_positions is not None and end_positions is not None:
+        if target1 is not None and target2 is not None:
             # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
+            if len(target1.size()) > 1:
+                target1 = target1.squeeze(-1)
+            if len(target2.size()) > 1:
+                target2 = target2.squeeze(-1)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
             ignored_index = start_logits.size(1)
-            start_positions.clamp_(0, ignored_index)
-            end_positions.clamp_(0, ignored_index)
+            target1.clamp_(0, ignored_index)
+            target2.clamp_(0, ignored_index)

             loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
+            start_loss = loss_fct(start_logits, target1)
+            end_loss = loss_fct(end_logits, target2)
             total_loss = (start_loss + end_loss) / 2
             return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
         else:
             return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}

-    def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)
+    def predict(self, words, seq_len1=None, seq_len2=None):
+        logits = self.forward(words, seq_len1, seq_len2)
         start_logits = logits[Const.OUTPUTS(0)]
         end_logits = logits[Const.OUTPUTS(1)]
         return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
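
Not part of the diff: a sketch of the renamed question-answering interface; `target1` / `target2` are the gold start / end positions and the outputs are keyed by Const.OUTPUTS(0) / Const.OUTPUTS(1). Construction, shapes, and the omitted masks are assumptions for illustration:

import torch
from fastNLP.core.const import Const
from fastNLP.models.bert import BertForQuestionAnswering
from fastNLP.modules.encoder.bert import BertConfig

model = BertForQuestionAnswering(config=BertConfig(30522))
words = torch.randint(0, 30522, (2, 32))            # [batch, seq_len] token ids

out = model(words, target1=torch.tensor([3, 5]), target2=torch.tensor([7, 9]))
print(out[Const.OUTPUTS(0)].shape, out[Const.LOSS].item())   # start logits and span loss

spans = model.predict(words)
print(spans[Const.OUTPUTS(0)])                        # predicted start positions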

