
Fix sequence labeling test errors

tags/v0.4.10
yh committed 5 years ago
commit dc2885c6bc
2 changed files with 31 additions and 27 deletions:
  1. fastNLP/models/sequence_labeling.py (+15 / -26)
  2. test/models/test_sequence_labeling.py (+16 / -1)

fastNLP/models/sequence_labeling.py (+15 / -26)

@@ -39,14 +39,14 @@ class BiLSTMCRF(BaseModel):
         self.embed = get_embeddings(embed)

         if num_layers>1:
-            self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
+            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
                              batch_first=True, dropout=dropout)
         else:
-            self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
+            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
                              batch_first=True)

         self.dropout = nn.Dropout(dropout)
-        self.fc = nn.Linear(hidden_size, num_classes)
+        self.fc = nn.Linear(hidden_size*2, num_classes)

         trans = None
         if target_vocab is not None and encoding_type is not None:
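
Note on the Linear fix: with bidirectional=True, the forward and backward hidden states are concatenated, so each timestep carries 2*hidden_size features, and the old in_features of hidden_size caused a shape mismatch. A minimal sketch in plain PyTorch (plain torch.nn.LSTM rather than fastNLP's wrapper; sizes are illustrative):

    import torch
    from torch import nn

    # A bidirectional LSTM emits 2 * hidden_size features per timestep.
    lstm = nn.LSTM(input_size=30, hidden_size=100, batch_first=True, bidirectional=True)
    out, _ = lstm(torch.randn(4, 7, 30))   # out: [4, 7, 200]
    fc = nn.Linear(100 * 2, 5)             # in_features must be 2 * hidden_size
    logits = fc(out)                       # [4, 7, 5]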
@@ -56,7 +56,7 @@ class BiLSTMCRF(BaseModel):

     def _forward(self, words, seq_len=None, target=None):
         words = self.embed(words)
-        feats = self.lstm(words, seq_len=seq_len)
+        feats, _ = self.lstm(words, seq_len=seq_len)
         feats = self.fc(feats)
         feats = self.dropout(feats)
         logits = F.log_softmax(feats, dim=-1)
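
The unpacking fix reflects the return type: like torch.nn.LSTM, the LSTM here returns a tuple rather than a bare tensor, so feeding the un-unpacked result into the Linear layer is what broke the test. Illustrated with plain torch.nn.LSTM:

    import torch
    from torch import nn

    lstm = nn.LSTM(input_size=30, hidden_size=100, batch_first=True, bidirectional=True)
    result = lstm(torch.randn(4, 7, 30))   # a tuple: (output, (h_n, c_n))
    feats, _ = result                      # keep only the per-timestep output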
@@ -142,8 +142,6 @@ class SeqLabeling(BaseModel):
"""
x = x.float()
y = y.long()
assert x.shape[:2] == y.shape
assert y.shape == self.mask.shape
total_loss = self.crf(x, y, mask)
return torch.mean(total_loss)
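
The dropped assertions validated y against a mask stored on the module, while the CRF call in the context line already receives mask explicitly; this commit generally moves masks from module state to arguments (see the AdvSeqLabel hunks below). If shape validation is still wanted, a stateless equivalent could check the argument itself (a hypothetical addition, not part of this commit):

    # x: [batch_size, max_len, tag_size]; y and mask: [batch_size, max_len]
    assert x.shape[:2] == y.shape == mask.shape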
@@ -195,36 +193,29 @@ class AdvSeqLabel(nn.Module):
                                              allowed_transitions=allowed_transitions(id2words,
                                                                                      encoding_type=encoding_type))

-    def _decode(self, x):
+    def _decode(self, x, mask):
         """
         :param torch.FloatTensor x: [batch_size, max_len, tag_size]
+        :param torch.ByteTensor mask: [batch_size, max_len]
         :return torch.LongTensor, [batch_size, max_len]
         """
-        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
+        tag_seq, _ = self.Crf.viterbi_decode(x, mask)
         return tag_seq

-    def _internal_loss(self, x, y):
+    def _internal_loss(self, x, y, mask):
         """
         Negative log likelihood loss.
         :param x: Tensor, [batch_size, max_len, tag_size]
         :param y: Tensor, [batch_size, max_len]
+        :param mask: Tensor, [batch_size, max_len]
         :return loss: a scalar Tensor

         """
         x = x.float()
         y = y.long()
-        assert x.shape[:2] == y.shape
-        assert y.shape == self.mask.shape
-        total_loss = self.Crf(x, y, self.mask)
+        total_loss = self.Crf(x, y, mask)
         return torch.mean(total_loss)

-    def _make_mask(self, x, seq_len):
-        batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_len_to_mask(seq_len)
-        mask = mask.view(batch_size, max_len)
-        mask = mask.to(x).float()
-        return mask

     def _forward(self, words, seq_len, target=None):
         """
         :param torch.LongTensor words: [batch_size, max_len]
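
With _make_mask deleted, AdvSeqLabel builds its mask with fastNLP's seq_len_to_mask utility instead (next hunk). Roughly what such a length mask computes, as a sketch rather than fastNLP's actual implementation:

    import torch

    def length_mask(seq_len, max_len=None):
        # Position j of row i is True iff j < seq_len[i].
        max_len = int(max_len if max_len is not None else seq_len.max())
        positions = torch.arange(max_len, device=seq_len.device)
        return positions.unsqueeze(0) < seq_len.unsqueeze(1)   # [batch_size, max_len]

    length_mask(torch.tensor([3, 5]), max_len=5)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])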
@@ -236,15 +227,13 @@ class AdvSeqLabel(nn.Module):
         words = words.long()
         seq_len = seq_len.long()
-        self.mask = self._make_mask(words, seq_len)
+        # seq_len = seq_len.long()
+        mask = seq_len_to_mask(seq_len, max_len=words.size(1))

         target = target.long() if target is not None else None

         if next(self.parameters()).is_cuda:
             words = words.cuda()
-            self.mask = self.mask.cuda()

         x = self.Embedding(words)
         x = self.norm1(x)
         # [batch_size, max_len, word_emb_dim]
@@ -257,9 +246,9 @@ class AdvSeqLabel(nn.Module):
         x = self.drop(x)
         x = self.Linear2(x)
         if target is not None:
-            return {"loss": self._internal_loss(x, target)}
+            return {"loss": self._internal_loss(x, target, mask)}
         else:
-            return {"pred": self._decode(x)}
+            return {"pred": self._decode(x, mask)}

     def forward(self, words, seq_len, target):
         """
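Taken together, the AdvSeqLabel changes thread the mask through _forward into _internal_loss and _decode as a plain argument instead of stashing it on self. That removes hidden cross-call state and the manual self.mask.cuda() transfer: the mask is rebuilt from seq_len on every call, presumably on the same device as seq_len.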

test/models/test_sequence_labeling.py (+16 / -1)

@@ -3,9 +3,24 @@
 import unittest

 from .model_runner import *
-from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel
+from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
 from fastNLP.core.losses import LossInForward

+class TestBiLSTM(unittest.TestCase):
+    def test_case1(self):
+        # Check that BiLSTMCRF trains and evaluates without errors
+        init_emb = (VOCAB_SIZE, 30)
+        model = BiLSTMCRF(init_emb,
+                          hidden_size=30,
+                          num_classes=NUM_CLS)
+
+        data = RUNNER.prepare_pos_tagging_data()
+        data.set_input('target')
+        loss = LossInForward()
+        metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
+        RUNNER.run_model(model, data, loss, metric)
+
+
 class TestSeqLabel(unittest.TestCase):
     def test_case1(self):
         # Check that SeqLabeling trains and evaluates without errors
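
The new test drives BiLSTMCRF end to end through the shared model runner. For reference, a hedged standalone sketch of the same entry points (this assumes BiLSTMCRF exposes the usual fastNLP forward/predict wrappers around _forward, and that get_embeddings expands a (vocab_size, embedding_dim) tuple into a fresh embedding layer; all sizes are illustrative):

    import torch
    from fastNLP.models.sequence_labeling import BiLSTMCRF

    model = BiLSTMCRF((100, 30), num_classes=5, hidden_size=30)

    words = torch.randint(1, 100, (4, 7))      # [batch_size, max_len] token ids
    seq_len = torch.tensor([7, 7, 6, 5])       # true length of each sequence
    target = torch.randint(0, 5, (4, 7))       # gold tag ids

    out = model(words, seq_len=seq_len, target=target)  # {"loss": scalar tensor}
    pred = model.predict(words, seq_len=seq_len)        # {"pred": [4, 7] tag ids}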

