|
-
-
- from torch import nn
- from fastNLP.modules import ConditionalRandomField, allowed_transitions
- import torch.nn.functional as F
-
class BertCRF(nn.Module):
    """Sequence tagger: token embedding -> linear projection -> CRF.

    Args:
        embed: token-embedding module. It is called on the word-index tensor
            and must expose ``embed_size``, the width of its output features.
        tag_vocab: tag vocabulary. ``len(tag_vocab)`` gives the number of
            tags, and the vocabulary is passed to ``allowed_transitions`` to
            build the CRF's transition constraints.
        encoding_type: tagging-scheme name forwarded to
            ``allowed_transitions`` (default ``'bio'``).
    """

    def __init__(self, embed, tag_vocab, encoding_type='bio'):
        super().__init__()
        self.embed = embed
        num_tags = len(tag_vocab)
        self.fc = nn.Linear(self.embed.embed_size, num_tags)
        # Restrict the CRF to transitions that are legal under the tagging
        # scheme, including the virtual start/end states.
        trans = allowed_transitions(tag_vocab, encoding_type=encoding_type,
                                    include_start_end=True)
        self.crf = ConditionalRandomField(num_tags,
                                          include_start_end_trans=True,
                                          allowed_transitions=trans)

    def _forward(self, words, target):
        """Shared path for training and decoding.

        Args:
            words: tensor of word indices.
            target: gold tag indices, or ``None`` to decode.

        Returns:
            ``{'loss': scalar tensor}`` when ``target`` is given, otherwise
            ``{'pred': paths}`` from the CRF's Viterbi decoding.
        """
        # Positions holding index 0 are treated as padding and masked out.
        # NOTE(review): assumes the word vocabulary places its padding token
        # at index 0 -- confirm against the vocabulary used by ``embed``.
        mask = words.ne(0)
        feats = self.fc(self.embed(words))
        logits = F.log_softmax(feats, dim=-1)
        if target is None:
            # The second value returned by viterbi_decode is not needed.
            paths, _ = self.crf.viterbi_decode(logits, mask)
            return {'pred': paths}
        # The CRF yields one loss per sequence (shape (batch_size,)); average
        # it to a scalar so it can be back-propagated directly.
        loss = self.crf(logits, target, mask).mean()
        return {'loss': loss}

    def forward(self, words, target=None):
        """Return ``{'loss': ...}`` if ``target`` is given, else ``{'pred': ...}``."""
        return self._forward(words, target)

    def predict(self, words):
        """Decode ``words`` and return ``{'pred': paths}``."""
        return self._forward(words, None)
|