import torch
from torch import nn

from fastNLP.modules.encoder.bert import BertModel
class Classifier(nn.Module):
    """Scores each sentence vector with a single sigmoid unit, zeroing out padded sentences."""

    def __init__(self, hidden_size):
        super(Classifier, self).__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, mask_cls):
        # inputs:   [batch_size, num_sentences, hidden_size] sentence ([CLS]) embeddings
        # mask_cls: [batch_size, num_sentences], 1 for real sentences, 0 for padding
        h = self.linear(inputs).squeeze(-1)  # [batch_size, num_sentences]
        sent_scores = self.sigmoid(h) * mask_cls.float()
        return sent_scores


class BertSum(nn.Module):
    """Extractive summarizer: a BERT encoder followed by a per-sentence sigmoid classifier."""

    def __init__(self, hidden_size=768):
        super(BertSum, self).__init__()

        self.hidden_size = hidden_size

        self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.decoder = Classifier(self.hidden_size)

    def forward(self, article, segment_id, cls_id):
        # article:    [batch_size, seq_len] token ids, 0 marks padding
        # segment_id: [batch_size, seq_len] token type ids for BERT
        # cls_id:     [batch_size, num_sentences] positions of the [CLS] tokens, -1 marks padding

        input_mask = (article != 0).long()
        mask_cls = (cls_id != -1).long()
        assert input_mask.size() == article.size()
        assert mask_cls.size() == cls_id.size()

        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
        bert_out = bert_out[0][-1]  # last encoder layer: [batch_size, seq_len, hidden_size]

        # gather the hidden state at each [CLS] position to obtain one vector per sentence
        sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)

        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, num_sentences]
        assert sent_scores.size() == (article.size(0), cls_id.size(1))

        return {'pred': sent_scores, 'mask': mask_cls}
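
# --- Hedged sketch (not part of the original module) --------------------------------
# BertSum-style extractive models are usually trained with per-sentence binary
# cross-entropy, masked so that padded sentence slots do not contribute. The helper
# below is only an assumption about how the {'pred', 'mask'} outputs could be consumed;
# the name `masked_bce_loss` and the `labels` tensor are illustrative, not from the repo.
def masked_bce_loss(pred, mask, labels):
    """pred, labels: [batch_size, num_sentences] in [0, 1]; mask: same shape, 1 = real sentence."""
    loss = nn.functional.binary_cross_entropy(pred, labels.float(), reduction='none')
    loss = loss * mask.float()
    # average only over real (unmasked) sentences
    return loss.sum() / mask.float().sum().clamp(min=1.0)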
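
# --- Hedged usage sketch -------------------------------------------------------------
# Assumptions: a batch of 2 documents, 512 tokens each, up to 3 sentences per document;
# the from_pretrained path above must point to a real uncased BERT checkpoint directory.
if __name__ == '__main__':
    model = BertSum(hidden_size=768)
    article = torch.randint(1, 1000, (2, 512))            # dummy token ids, 0 = padding
    segment_id = torch.zeros(2, 512, dtype=torch.long)    # single-segment input
    cls_id = torch.tensor([[0, 150, 300], [0, 200, -1]])  # [CLS] positions, -1 = padding
    out = model(article, segment_id, cls_id)
    print(out['pred'].size())  # torch.Size([2, 3]), sentence scores in [0, 1]
    print(out['mask'])         # [[1, 1, 1], [1, 1, 0]]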