From 6d36dbe7fb0358e70b87f343b5533a964245563d Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 9 May 2019 00:15:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/models/__init__.py | 2 +- ...uence_modeling.py => sequence_labeling.py} | 44 ++++++++----------- reproduction/POS_tagging/train_pos_tag.py | 2 +- test/data_for_tests/word2vec_test.txt | 7 +++ test/io/test_embed_loader.py | 16 ++++++- ...cnn.py => test_cnn_text_classification.py} | 6 +-- test/models/test_sequence_labeling.py | 36 +++++++++++++++ 7 files changed, 78 insertions(+), 35 deletions(-) rename fastNLP/models/{sequence_modeling.py => sequence_labeling.py} (85%) create mode 100644 test/data_for_tests/word2vec_test.txt rename test/models/{test_cnn.py => test_cnn_text_classification.py} (83%) create mode 100644 test/models/test_sequence_labeling.py diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py index 59200773..f0d84b1c 100644 --- a/fastNLP/models/__init__.py +++ b/fastNLP/models/__init__.py @@ -8,6 +8,6 @@ from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequen BertForTokenClassification from .biaffine_parser import BiaffineParser, GraphParser from .cnn_text_classification import CNNText -from .sequence_modeling import SeqLabeling, AdvSeqLabel +from .sequence_labeling import SeqLabeling, AdvSeqLabel from .snli import ESIM from .star_transformer import STSeqCls, STNLICls, STSeqLabel diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_labeling.py similarity index 85% rename from fastNLP/models/sequence_modeling.py rename to fastNLP/models/sequence_labeling.py index ffa24940..880bd8a8 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_labeling.py @@ -43,7 +43,7 @@ class SeqLabeling(BaseModel): x = self.Embedding(words) # [batch_size, max_len, word_emb_dim] - x = self.Rnn(x) + x,_ = self.Rnn(x, seq_len) # [batch_size, max_len, hidden_size * direction] x = self.Linear(x) # [batch_size, max_len, num_classes] @@ -55,13 +55,13 @@ class SeqLabeling(BaseModel): :param torch.LongTensor words: [batch_size, max_len] :param torch.LongTensor seq_len: [batch_size,] - :return: + :return: {'pred': xx}, [batch_size, max_len] """ self.mask = self._make_mask(words, seq_len) x = self.Embedding(words) # [batch_size, max_len, word_emb_dim] - x = self.Rnn(x) + x, _ = self.Rnn(x, seq_len) # [batch_size, max_len, hidden_size * direction] x = self.Linear(x) # [batch_size, max_len, num_classes] @@ -93,13 +93,13 @@ class SeqLabeling(BaseModel): def _decode(self, x): """ :param torch.FloatTensor x: [batch_size, max_len, tag_size] - :return prediction: list of [decode path(list)] + :return prediction: [batch_size, max_len] """ - tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) + tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) return tag_seq -class AdvSeqLabel: +class AdvSeqLabel(nn.Module): """ 更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 """ @@ -115,17 +115,19 @@ class AdvSeqLabel: :param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S' 不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证 'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。) - :param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 + :param str encoding_type: 支持"BIO", "BMES", "BEMSO", 只有在id2words不为None的情况游泳。 """ + super().__init__() + self.Embedding = encoder.embedding.Embedding(init_embed) self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim) - self.Rnn = torch.nn.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout, + self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) - self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3) + self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3) self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) self.relu = torch.nn.LeakyReLU() self.drop = torch.nn.Dropout(dropout) - self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes) + self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes) if id2words is None: self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) @@ -137,9 +139,9 @@ class AdvSeqLabel: def _decode(self, x): """ :param torch.FloatTensor x: [batch_size, max_len, tag_size] - :return prediction: list of [decode path(list)] + :return torch.LongTensor, [batch_size, max_len] """ - tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) + tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) return tag_seq def _internal_loss(self, x, y): @@ -176,31 +178,20 @@ class AdvSeqLabel: words = words.long() seq_len = seq_len.long() self.mask = self._make_mask(words, seq_len) - sent_len, idx_sort = torch.sort(seq_len, descending=True) - _, idx_unsort = torch.sort(idx_sort, descending=False) # seq_len = seq_len.long() target = target.long() if target is not None else None if next(self.parameters()).is_cuda: words = words.cuda() - idx_sort = idx_sort.cuda() - idx_unsort = idx_unsort.cuda() self.mask = self.mask.cuda() x = self.Embedding(words) x = self.norm1(x) # [batch_size, max_len, word_emb_dim] - sent_variable = x[idx_sort] - sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) - - x, _ = self.Rnn(sent_packed) + x, _ = self.Rnn(x, seq_len=seq_len) - sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] - x = sent_output[idx_unsort] - - x = x.contiguous() x = self.Linear1(x) x = self.norm2(x) x = self.relu(x) @@ -225,6 +216,7 @@ class AdvSeqLabel: :param torch.LongTensor words: [batch_size, mex_len] :param torch.LongTensor seq_len:[batch_size, ] - :return: [list1, list2, ...], 内部每个list为一个路径,已经unpad了。 + :return {'pred':}, value是torch.LongTensor, [batch_size, max_len] + """ - return self._forward(words, seq_len, ) + return self._forward(words, seq_len) diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py index 06547701..ccf7aa1e 100644 --- a/reproduction/POS_tagging/train_pos_tag.py +++ b/reproduction/POS_tagging/train_pos_tag.py @@ -13,7 +13,7 @@ from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInp from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.trainer import Trainer from fastNLP.io.config_io import ConfigLoader, ConfigSection -from fastNLP.models.sequence_modeling import AdvSeqLabel +from fastNLP.models.sequence_labeling import AdvSeqLabel from fastNLP.io.dataset_loader import ConllxDataLoader from fastNLP.api.processor import ModelProcessor, Index2WordProcessor diff --git a/test/data_for_tests/word2vec_test.txt b/test/data_for_tests/word2vec_test.txt new file mode 100644 index 00000000..c16170f2 --- /dev/null +++ b/test/data_for_tests/word2vec_test.txt @@ -0,0 +1,7 @@ +5 50 +the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 +of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 +to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 +and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 +in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 +a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 \ No newline at end of file diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py index 05a127a9..d43a00fe 100644 --- a/test/io/test_embed_loader.py +++ b/test/io/test_embed_loader.py @@ -3,7 +3,9 @@ import numpy as np from fastNLP import Vocabulary from fastNLP.io import EmbedLoader - +import os +from fastNLP.io.dataset_loader import SSTLoader +from fastNLP.core.const import Const as C class TestEmbedLoader(unittest.TestCase): def test_load_with_vocab(self): @@ -36,4 +38,14 @@ class TestEmbedLoader(unittest.TestCase): self.assertEqual(w_m.shape, (7, 50)) self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7) for word in words: - self.assertIn(word, vocab) \ No newline at end of file + self.assertIn(word, vocab) + + def test_read_all_glove(self): + pass + # 这是可以运行的,但是总数少于行数,应该是由于glove有重复的word + # path = '/where/to/read/full/glove' + # init_embed, vocab = EmbedLoader.load_without_vocab(path, error='strict') + # print(init_embed.shape) + # print(init_embed.mean()) + # print(np.isnan(init_embed).sum()) + # print(len(vocab)) diff --git a/test/models/test_cnn.py b/test/models/test_cnn_text_classification.py similarity index 83% rename from test/models/test_cnn.py rename to test/models/test_cnn_text_classification.py index 61b75703..b83b7bad 100644 --- a/test/models/test_cnn.py +++ b/test/models/test_cnn_text_classification.py @@ -1,7 +1,7 @@ import unittest -from test.models.model_runner import * +from .model_runner import * from fastNLP.models.cnn_text_classification import CNNText @@ -16,7 +16,3 @@ class TestCNNText(unittest.TestCase): padding=0, dropout=0.5) RUNNER.run_model_with_task(TEXT_CLS, model) - - -if __name__ == '__main__': - TestCNNText().test_case1() \ No newline at end of file diff --git a/test/models/test_sequence_labeling.py b/test/models/test_sequence_labeling.py new file mode 100644 index 00000000..3a70e381 --- /dev/null +++ b/test/models/test_sequence_labeling.py @@ -0,0 +1,36 @@ + + +import unittest + +from .model_runner import * +from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel +from fastNLP.core.losses import LossInForward + +class TesSeqLabel(unittest.TestCase): + def test_case1(self): + # 测试能否正常运行CNN + init_emb = (VOCAB_SIZE, 30) + model = SeqLabeling(init_emb, + hidden_size=30, + num_classes=NUM_CLS) + + data = RUNNER.prepare_pos_tagging_data() + data.set_input('target') + loss = LossInForward() + metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN) + RUNNER.run_model(model, data, loss, metric) + + +class TesAdvSeqLabel(unittest.TestCase): + def test_case1(self): + # 测试能否正常运行CNN + init_emb = (VOCAB_SIZE, 30) + model = AdvSeqLabel(init_emb, + hidden_size=30, + num_classes=NUM_CLS) + + data = RUNNER.prepare_pos_tagging_data() + data.set_input('target') + loss = LossInForward() + metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN) + RUNNER.run_model(model, data, loss, metric) \ No newline at end of file