From dc2885c6bc3f57b5f4a3974921e12844f6142b09 Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 1 Sep 2019 02:00:03 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dsequence=20labeling=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/models/sequence_labeling.py | 41 ++++++++++----------------- test/models/test_sequence_labeling.py | 17 ++++++++++- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/fastNLP/models/sequence_labeling.py b/fastNLP/models/sequence_labeling.py index 0c573a90..6e839bea 100644 --- a/fastNLP/models/sequence_labeling.py +++ b/fastNLP/models/sequence_labeling.py @@ -39,14 +39,14 @@ class BiLSTMCRF(BaseModel): self.embed = get_embeddings(embed) if num_layers>1: - self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, + self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, batch_first=True, dropout=dropout) else: - self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, + self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, batch_first=True) self.dropout = nn.Dropout(dropout) - self.fc = nn.Linear(hidden_size, num_classes) + self.fc = nn.Linear(hidden_size*2, num_classes) trans = None if target_vocab is not None and encoding_type is not None: @@ -56,7 +56,7 @@ class BiLSTMCRF(BaseModel): def _forward(self, words, seq_len=None, target=None): words = self.embed(words) - feats = self.lstm(words, seq_len=seq_len) + feats, _ = self.lstm(words, seq_len=seq_len) feats = self.fc(feats) feats = self.dropout(feats) logits = F.log_softmax(feats, dim=-1) @@ -142,8 +142,6 @@ class SeqLabeling(BaseModel): """ x = x.float() y = y.long() - assert x.shape[:2] == y.shape - assert y.shape == self.mask.shape total_loss = self.crf(x, y, mask) return torch.mean(total_loss) @@ -195,36 +193,29 @@ class AdvSeqLabel(nn.Module): allowed_transitions=allowed_transitions(id2words, encoding_type=encoding_type)) - def _decode(self, x): + def _decode(self, x, mask): """ :param torch.FloatTensor x: [batch_size, max_len, tag_size] + :param torch.ByteTensor mask: [batch_size, max_len] :return torch.LongTensor, [batch_size, max_len] """ - tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) + tag_seq, _ = self.Crf.viterbi_decode(x, mask) return tag_seq - def _internal_loss(self, x, y): + def _internal_loss(self, x, y, mask): """ Negative log likelihood loss. :param x: Tensor, [batch_size, max_len, tag_size] :param y: Tensor, [batch_size, max_len] + :param mask: Tensor, [batch_size, max_len] :return loss: a scalar Tensor """ x = x.float() y = y.long() - assert x.shape[:2] == y.shape - assert y.shape == self.mask.shape - total_loss = self.Crf(x, y, self.mask) + total_loss = self.Crf(x, y, mask) return torch.mean(total_loss) - def _make_mask(self, x, seq_len): - batch_size, max_len = x.size(0), x.size(1) - mask = seq_len_to_mask(seq_len) - mask = mask.view(batch_size, max_len) - mask = mask.to(x).float() - return mask - def _forward(self, words, seq_len, target=None): """ :param torch.LongTensor words: [batch_size, mex_len] @@ -236,15 +227,13 @@ class AdvSeqLabel(nn.Module): words = words.long() seq_len = seq_len.long() - self.mask = self._make_mask(words, seq_len) - - # seq_len = seq_len.long() + mask = seq_len_to_mask(seq_len, max_len=words.size(1)) + target = target.long() if target is not None else None if next(self.parameters()).is_cuda: words = words.cuda() - self.mask = self.mask.cuda() - + x = self.Embedding(words) x = self.norm1(x) # [batch_size, max_len, word_emb_dim] @@ -257,9 +246,9 @@ class AdvSeqLabel(nn.Module): x = self.drop(x) x = self.Linear2(x) if target is not None: - return {"loss": self._internal_loss(x, target)} + return {"loss": self._internal_loss(x, target, mask)} else: - return {"pred": self._decode(x)} + return {"pred": self._decode(x, mask)} def forward(self, words, seq_len, target): """ diff --git a/test/models/test_sequence_labeling.py b/test/models/test_sequence_labeling.py index 3a70e381..815d7047 100644 --- a/test/models/test_sequence_labeling.py +++ b/test/models/test_sequence_labeling.py @@ -3,9 +3,24 @@ import unittest from .model_runner import * -from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel +from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF from fastNLP.core.losses import LossInForward +class TestBiLSTM(unittest.TestCase): + def test_case1(self): + # 测试能否正常运行CNN + init_emb = (VOCAB_SIZE, 30) + model = BiLSTMCRF(init_emb, + hidden_size=30, + num_classes=NUM_CLS) + + data = RUNNER.prepare_pos_tagging_data() + data.set_input('target') + loss = LossInForward() + metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN) + RUNNER.run_model(model, data, loss, metric) + + class TesSeqLabel(unittest.TestCase): def test_case1(self): # 测试能否正常运行CNN