@@ -39,14 +39,14 @@ class BiLSTMCRF(BaseModel):
         self.embed = get_embeddings(embed)

         if num_layers>1:
-            self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
+            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
                              batch_first=True, dropout=dropout)
         else:
-            self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
+            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
                              batch_first=True)

         self.dropout = nn.Dropout(dropout)
-        self.fc = nn.Linear(hidden_size, num_classes)
+        self.fc = nn.Linear(hidden_size*2, num_classes)

         trans = None
         if target_vocab is not None and encoding_type is not None:
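
The two fixes in this hunk are easiest to see with plain tensor shapes. A bidirectional LSTM concatenates its forward and backward hidden states, so each token comes out with hidden_size*2 features and the projection layer has to expect that; and since get_embeddings() may build a fresh module from whatever was passed in (for example a (num_embeddings, dim) tuple), the embedding dimension should be read from the constructed self.embed rather than from the raw embed argument. A minimal sketch with torch.nn only, using made-up sizes:

    import torch
    import torch.nn as nn

    emb_dim, hidden_size, num_classes = 50, 100, 7        # illustrative sizes
    lstm = nn.LSTM(emb_dim, hidden_size, bidirectional=True, batch_first=True)
    fc = nn.Linear(hidden_size * 2, num_classes)          # hidden_size alone would not match

    x = torch.randn(4, 13, emb_dim)                       # [batch_size, max_len, emb_dim]
    out, _ = lstm(x)
    print(out.shape)                                      # torch.Size([4, 13, 200]) -> hidden_size*2
    print(fc(out).shape)                                  # torch.Size([4, 13, 7])
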
@@ -56,7 +56,7 @@ class BiLSTMCRF(BaseModel):

     def _forward(self, words, seq_len=None, target=None):
         words = self.embed(words)
-        feats = self.lstm(words, seq_len=seq_len)
+        feats, _ = self.lstm(words, seq_len=seq_len)
         feats = self.fc(feats)
         feats = self.dropout(feats)
         logits = F.log_softmax(feats, dim=-1)
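
The one-line change in _forward matters because an LSTM returns a tuple rather than a single tensor: with torch.nn.LSTM the call yields (output, (h_n, c_n)), and fastNLP's LSTM wrapper appears to follow the same convention apart from its extra seq_len handling, so the old code handed a tuple straight to self.fc. A small stand-alone illustration with torch.nn.LSTM:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(50, 100, bidirectional=True, batch_first=True)
    words = torch.randn(2, 5, 50)

    whole = lstm(words)               # a tuple: (output, (h_n, c_n))
    feats, _ = lstm(words)            # keep only the per-token outputs
    print(type(whole), feats.shape)   # <class 'tuple'> torch.Size([2, 5, 200])
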
@@ -142,8 +142,6 @@ class SeqLabeling(BaseModel):
         """
         x = x.float()
         y = y.long()
-        assert x.shape[:2] == y.shape
-        assert y.shape == self.mask.shape
         total_loss = self.crf(x, y, mask)
         return torch.mean(total_loss)
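
The two asserts disappear together with self.mask: the mask is now handed to _internal_loss by the caller, so there is no module attribute left to compare shapes against. Roughly what the remaining lines compute, assuming fastNLP's ConditionalRandomField (whose call returns one negative log-likelihood per sequence) and its seq_len_to_mask helper; the tag count and lengths below are invented:

    import torch
    from fastNLP.modules import ConditionalRandomField
    from fastNLP.core.utils import seq_len_to_mask

    crf = ConditionalRandomField(5)                   # 5 tags
    feats = torch.randn(2, 6, 5)                      # [batch_size, max_len, tag_size]
    tags = torch.randint(0, 5, (2, 6))                # [batch_size, max_len]
    mask = seq_len_to_mask(torch.tensor([6, 4]))      # 0 at padded positions

    loss = crf(feats, tags, mask).mean()              # scalar, as returned by _internal_loss
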
@@ -195,36 +193,29 @@ class AdvSeqLabel(nn.Module):
                                               allowed_transitions=allowed_transitions(id2words,
                                                                                       encoding_type=encoding_type))

-    def _decode(self, x):
+    def _decode(self, x, mask):
         """
         :param torch.FloatTensor x: [batch_size, max_len, tag_size]
+        :param torch.ByteTensor mask: [batch_size, max_len]
         :return torch.LongTensor, [batch_size, max_len]
         """
-        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
+        tag_seq, _ = self.Crf.viterbi_decode(x, mask)
         return tag_seq
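
_decode gets the same treatment: the mask arrives as an explicit argument instead of a leftover self.mask from the previous forward pass, which keeps the method stateless and safe to call on any batch. A sketch of the underlying call, again assuming fastNLP's CRF API, where viterbi_decode takes the emission scores plus the mask and returns the best tag sequence per sentence along with its score:

    import torch
    from fastNLP.modules import ConditionalRandomField
    from fastNLP.core.utils import seq_len_to_mask

    crf = ConditionalRandomField(5)
    logits = torch.randn(2, 6, 5)                     # [batch_size, max_len, tag_size]
    mask = seq_len_to_mask(torch.tensor([6, 3]))      # second sentence has 3 real tokens

    tag_seq, scores = crf.viterbi_decode(logits, mask)
    print(tag_seq.shape)                              # torch.Size([2, 6])
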
-    def _internal_loss(self, x, y):
+    def _internal_loss(self, x, y, mask):
         """
         Negative log likelihood loss.
         :param x: Tensor, [batch_size, max_len, tag_size]
         :param y: Tensor, [batch_size, max_len]
+        :param mask: Tensor, [batch_size, max_len]
         :return loss: a scalar Tensor

         """
         x = x.float()
         y = y.long()
-        assert x.shape[:2] == y.shape
-        assert y.shape == self.mask.shape
-        total_loss = self.Crf(x, y, self.mask)
+        total_loss = self.Crf(x, y, mask)
         return torch.mean(total_loss)

     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
         mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask

     def _forward(self, words, seq_len, target=None):
         """
         :param torch.LongTensor words: [batch_size, mex_len]
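
Both the unchanged _make_mask helper above and the rewritten _forward below lean on seq_len_to_mask, which turns a vector of sentence lengths into a [batch_size, max_len] matrix that is 1 (True) on real tokens and 0 on padding. Passing max_len explicitly, as the new _forward does, pins the mask width to the padded batch even when it differs from the longest length in the tensor. A tiny example with invented lengths:

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    seq_len = torch.tensor([4, 2])
    print(seq_len_to_mask(seq_len).shape)             # torch.Size([2, 4])
    print(seq_len_to_mask(seq_len, max_len=6).shape)  # torch.Size([2, 6]), matching a batch padded to 6
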
@@ -236,15 +227,13 @@ class AdvSeqLabel(nn.Module):

         words = words.long()
-        seq_len = seq_len.long()
-        self.mask = self._make_mask(words, seq_len)
+        # seq_len = seq_len.long()
+        mask = seq_len_to_mask(seq_len, max_len=words.size(1))

         target = target.long() if target is not None else None

-        if next(self.parameters()).is_cuda:
-            words = words.cuda()
-            self.mask = self.mask.cuda()

         x = self.Embedding(words)
         x = self.norm1(x)
         # [batch_size, max_len, word_emb_dim]
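
Dropping the manual .cuda() calls works because nothing is cached on the module any more: the mask is a local tensor derived from seq_len, and seq_len_to_mask builds it on the same device as seq_len, so the mask follows whatever device the batch arrives on. A hedged sketch of that assumption:

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    words = torch.randint(0, 1000, (2, 6), device=device)
    seq_len = torch.tensor([6, 3], device=device)

    mask = seq_len_to_mask(seq_len, max_len=words.size(1))
    print(mask.device)                                # same device as seq_len / words
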
@@ -257,9 +246,9 @@ class AdvSeqLabel(nn.Module):
         x = self.drop(x)
         x = self.Linear2(x)
         if target is not None:
-            return {"loss": self._internal_loss(x, target)}
+            return {"loss": self._internal_loss(x, target, mask)}
         else:
-            return {"pred": self._decode(x)}
+            return {"pred": self._decode(x, mask)}

     def forward(self, words, seq_len, target):
         """