Browse Source

增加cnn的测试

tags/v0.4.10
yh 5 years ago
parent
commit
25565fe0c9
5 changed files with 46 additions and 25 deletions
  1. +12
    -2
      fastNLP/core/tester.py
  2. +0
    -2
      fastNLP/core/trainer.py
  3. +7
    -17
      fastNLP/models/cnn_text_classification.py
  4. +5
    -4
      fastNLP/models/sequence_modeling.py
  5. +22
    -0
      test/models/test_cnn.py

+ 12
- 2
fastNLP/core/tester.py View File

@@ -27,11 +27,13 @@ Example::
tester = Tester(dataset, model, metrics=AccuracyMetric())
eval_results = tester.test()

这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块<fastNLP.core.trainer>` 的1.3部分
这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块<fastNLP.core.trainer>` 的1.3部分
Tester在验证进行之前会调用model.eval()提示当前进入了evaluation阶段,即会关闭nn.Dropout()等,在验证结束之后会调用model.train()恢复到训练状态。


"""
import warnings

import torch
from torch import nn

@@ -72,6 +74,7 @@ class Tester(object):

5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。

如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
"""

@@ -90,6 +93,13 @@ class Tester(object):
self.batch_size = batch_size
self.verbose = verbose

# 如果是DataParallel将没有办法使用predict方法
if isinstance(self._model, nn.DataParallel):
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function,"
" while DataParallel has no predict() function.")
self._model = self._model.module

# check predict
if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict


+ 0
- 2
fastNLP/core/trainer.py View File

@@ -462,7 +462,6 @@ class Trainer(object):
self.best_dev_perf = None
self.sampler = sampler if sampler is not None else RandomSampler()
self.prefetch = prefetch
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
self.n_steps = (len(self.train_data) // self.batch_size + int(
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
@@ -492,7 +491,6 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp
print("callback_manager")
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)


+ 7
- 17
fastNLP/models/cnn_text_classification.py View File

@@ -3,9 +3,9 @@

import torch
import torch.nn as nn
import numpy as np
from ..core.const import Const as C

import fastNLP.modules.encoder as encoder
from ..modules import encoder


class CNNText(torch.nn.Module):
@@ -18,7 +18,7 @@ class CNNText(torch.nn.Module):
:param int num_classes: 一共有多少类
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param int padding:
:param int padding: 对句子前后的pad的大小, 用0填充。
:param float dropout: Dropout的大小
"""

@@ -38,17 +38,7 @@ class CNNText(torch.nn.Module):
kernel_sizes=kernel_sizes,
padding=padding)
self.dropout = nn.Dropout(dropout)
self.fc = encoder.Linear(sum(kernel_nums), num_classes)

def init_embed(self, embed):
"""
加载预训练的模型
:param numpy.ndarray embed: vocab_size x embed_dim的embedding
:return:
"""
assert isinstance(embed, np.ndarray)
assert embed.shape == self.embed.embed.weight.shape
self.embed.embed.weight.data = torch.from_numpy(embed)
self.fc = nn.Linear(sum(kernel_nums), num_classes)

def forward(self, words, seq_len=None):
"""
@@ -61,7 +51,7 @@ class CNNText(torch.nn.Module):
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
return {'pred': x}
return {C.OUTPUT: x}

def predict(self, words, seq_len=None):
"""
@@ -71,5 +61,5 @@ class CNNText(torch.nn.Module):
:return predict: dict of torch.LongTensor, [batch_size, ]
"""
output = self(words, seq_len)
_, predict = output['pred'].max(dim=1)
return {'pred': predict}
_, predict = output[C.OUTPUT].max(dim=1)
return {C.OUTPUT: predict}

+ 5
- 4
fastNLP/models/sequence_modeling.py View File

@@ -4,7 +4,8 @@ from .base_model import BaseModel
from ..modules import decoder, encoder
from ..modules.decoder.CRF import allowed_transitions
from ..modules.utils import seq_mask

from ..core.const import Const as C
from torch import nn

class SeqLabeling(BaseModel):
"""
@@ -24,7 +25,7 @@ class SeqLabeling(BaseModel):

self.Embedding = encoder.embedding.Embedding(init_embed)
self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size)
self.Linear = encoder.linear.Linear(hidden_size, num_classes)
self.Linear = nn.Linear(hidden_size, num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.mask = None

@@ -46,7 +47,7 @@ class SeqLabeling(BaseModel):
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
return {"loss": self._internal_loss(x, target)}
return {C.LOSS: self._internal_loss(x, target)}

def predict(self, words, seq_len):
"""
@@ -65,7 +66,7 @@ class SeqLabeling(BaseModel):
x = self.Linear(x)
# [batch_size, max_len, num_classes]
pred = self._decode(x)
return {'pred': pred}
return {C.OUTPUT: pred}

def _internal_loss(self, x, y):
"""


+ 22
- 0
test/models/test_cnn.py View File

@@ -0,0 +1,22 @@

import unittest

from test.models.model_runner import *
from fastNLP.models.cnn_text_classification import CNNText


class TestCNNText(unittest.TestCase):
def test_case1(self):
# 测试能否正常运行CNN
init_emb = (VOCAB_SIZE, 30)
model = CNNText(init_emb,
NUM_CLS,
kernel_nums=(1, 3, 5),
kernel_sizes=(2, 2, 2),
padding=0,
dropout=0.5)
RUNNER.run_model_with_task(TEXT_CLS, model)


if __name__ == '__main__':
TestCNNText().test_case1()

Loading…
Cancel
Save