From 25565fe0c931d732d2243222d07fb6d82845d9ae Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 7 May 2019 20:42:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0cnn=E7=9A=84=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 14 +++++++++++-- fastNLP/core/trainer.py | 2 -- fastNLP/models/cnn_text_classification.py | 24 +++++++---------------- fastNLP/models/sequence_modeling.py | 9 +++++---- test/models/test_cnn.py | 22 +++++++++++++++++++++ 5 files changed, 46 insertions(+), 25 deletions(-) create mode 100644 test/models/test_cnn.py diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 6eaa5add..7b6fdda5 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -27,11 +27,13 @@ Example:: tester = Tester(dataset, model, metrics=AccuracyMetric()) eval_results = tester.test() -这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块` 的1.3部分 - +这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块` 的1.3部分。 +Tester在验证进行之前会调用model.eval()提示当前进入了evaluation阶段,即会关闭nn.Dropout()等,在验证结束之后会调用model.train()恢复到训练状态。 """ +import warnings + import torch from torch import nn @@ -72,6 +74,7 @@ class Tester(object): 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 + 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 """ @@ -90,6 +93,13 @@ class Tester(object): self.batch_size = batch_size self.verbose = verbose + # 如果是DataParallel将没有办法使用predict方法 + if isinstance(self._model, nn.DataParallel): + if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): + warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," + " while DataParallel has no predict() function.") + self._model = self._model.module + # check predict if hasattr(self._model, 'predict'): self._predict_func = self._model.predict diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index cb2ff821..ed76ee8e 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -462,7 +462,6 @@ class Trainer(object): self.best_dev_perf = None self.sampler = sampler if sampler is not None else RandomSampler() self.prefetch = prefetch - self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) self.n_steps = (len(self.train_data) // self.batch_size + int( len(self.train_data) % self.batch_size != 0)) * self.n_epochs @@ -492,7 +491,6 @@ class Trainer(object): self.step = 0 self.start_time = None # start timestamp - print("callback_manager") self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index 7d7c3878..5df4e62a 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn -import numpy as np +from ..core.const import Const as C -import fastNLP.modules.encoder as encoder +from ..modules import encoder class CNNText(torch.nn.Module): @@ -18,7 +18,7 @@ class CNNText(torch.nn.Module): :param int num_classes: 一共有多少类 :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 - :param int padding: + :param int padding: 对句子前后的pad的大小, 用0填充。 :param float dropout: Dropout的大小 """ @@ -38,17 +38,7 @@ class CNNText(torch.nn.Module): kernel_sizes=kernel_sizes, padding=padding) self.dropout = nn.Dropout(dropout) - self.fc = encoder.Linear(sum(kernel_nums), num_classes) - - def init_embed(self, embed): - """ - 加载预训练的模型 - :param numpy.ndarray embed: vocab_size x embed_dim的embedding - :return: - """ - assert isinstance(embed, np.ndarray) - assert embed.shape == self.embed.embed.weight.shape - self.embed.embed.weight.data = torch.from_numpy(embed) + self.fc = nn.Linear(sum(kernel_nums), num_classes) def forward(self, words, seq_len=None): """ @@ -61,7 +51,7 @@ class CNNText(torch.nn.Module): x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.dropout(x) x = self.fc(x) # [N,C] -> [N, N_class] - return {'pred': x} + return {C.OUTPUT: x} def predict(self, words, seq_len=None): """ @@ -71,5 +61,5 @@ class CNNText(torch.nn.Module): :return predict: dict of torch.LongTensor, [batch_size, ] """ output = self(words, seq_len) - _, predict = output['pred'].max(dim=1) - return {'pred': predict} + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index e076910f..ffa24940 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -4,7 +4,8 @@ from .base_model import BaseModel from ..modules import decoder, encoder from ..modules.decoder.CRF import allowed_transitions from ..modules.utils import seq_mask - +from ..core.const import Const as C +from torch import nn class SeqLabeling(BaseModel): """ @@ -24,7 +25,7 @@ class SeqLabeling(BaseModel): self.Embedding = encoder.embedding.Embedding(init_embed) self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size) - self.Linear = encoder.linear.Linear(hidden_size, num_classes) + self.Linear = nn.Linear(hidden_size, num_classes) self.Crf = decoder.CRF.ConditionalRandomField(num_classes) self.mask = None @@ -46,7 +47,7 @@ class SeqLabeling(BaseModel): # [batch_size, max_len, hidden_size * direction] x = self.Linear(x) # [batch_size, max_len, num_classes] - return {"loss": self._internal_loss(x, target)} + return {C.LOSS: self._internal_loss(x, target)} def predict(self, words, seq_len): """ @@ -65,7 +66,7 @@ class SeqLabeling(BaseModel): x = self.Linear(x) # [batch_size, max_len, num_classes] pred = self._decode(x) - return {'pred': pred} + return {C.OUTPUT: pred} def _internal_loss(self, x, y): """ diff --git a/test/models/test_cnn.py b/test/models/test_cnn.py new file mode 100644 index 00000000..61b75703 --- /dev/null +++ b/test/models/test_cnn.py @@ -0,0 +1,22 @@ + +import unittest + +from test.models.model_runner import * +from fastNLP.models.cnn_text_classification import CNNText + + +class TestCNNText(unittest.TestCase): + def test_case1(self): + # 测试能否正常运行CNN + init_emb = (VOCAB_SIZE, 30) + model = CNNText(init_emb, + NUM_CLS, + kernel_nums=(1, 3, 5), + kernel_sizes=(2, 2, 2), + padding=0, + dropout=0.5) + RUNNER.run_model_with_task(TEXT_CLS, model) + + +if __name__ == '__main__': + TestCNNText().test_case1() \ No newline at end of file