@@ -27,11 +27,13 @@ Example:: | |||
tester = Tester(dataset, model, metrics=AccuracyMetric()) | |||
eval_results = tester.test() | |||
这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块<fastNLP.core.trainer>` 的1.3部分 | |||
这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块<fastNLP.core.trainer>` 的1.3部分。 | |||
Tester在验证进行之前会调用model.eval()提示当前进入了evaluation阶段,即会关闭nn.Dropout()等,在验证结束之后会调用model.train()恢复到训练状态。 | |||
""" | |||
import warnings | |||
import torch | |||
from torch import nn | |||
@@ -72,6 +74,7 @@ class Tester(object): | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
""" | |||
@@ -90,6 +93,13 @@ class Tester(object): | |||
self.batch_size = batch_size | |||
self.verbose = verbose | |||
# 如果是DataParallel将没有办法使用predict方法 | |||
if isinstance(self._model, nn.DataParallel): | |||
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): | |||
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," | |||
" while DataParallel has no predict() function.") | |||
self._model = self._model.module | |||
# check predict | |||
if hasattr(self._model, 'predict'): | |||
self._predict_func = self._model.predict | |||
@@ -462,7 +462,6 @@ class Trainer(object): | |||
self.best_dev_perf = None | |||
self.sampler = sampler if sampler is not None else RandomSampler() | |||
self.prefetch = prefetch | |||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
@@ -492,7 +491,6 @@ class Trainer(object): | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
print("callback_manager") | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
@@ -3,9 +3,9 @@ | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
from ..core.const import Const as C | |||
import fastNLP.modules.encoder as encoder | |||
from ..modules import encoder | |||
class CNNText(torch.nn.Module): | |||
@@ -18,7 +18,7 @@ class CNNText(torch.nn.Module): | |||
:param int num_classes: 一共有多少类 | |||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||
:param int padding: | |||
:param int padding: 对句子前后的pad的大小, 用0填充。 | |||
:param float dropout: Dropout的大小 | |||
""" | |||
@@ -38,17 +38,7 @@ class CNNText(torch.nn.Module): | |||
kernel_sizes=kernel_sizes, | |||
padding=padding) | |||
self.dropout = nn.Dropout(dropout) | |||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | |||
def init_embed(self, embed):
    """
    Load pretrained embedding values into the model's embedding layer.

    The values are copied into the existing weight tensor in place, so the
    layer keeps its current dtype and device. Rebinding ``weight.data`` to
    ``torch.from_numpy(embed)`` would instead adopt the array's dtype
    (float64 by default for numpy) and place the weights on the CPU.

    :param numpy.ndarray embed: pretrained embedding matrix; its shape must
        equal the shape of ``self.embed.embed.weight``
        (vocab_size x embed_dim)
    :return: None
    :raises AssertionError: if ``embed`` is not a ``numpy.ndarray`` or its
        shape differs from the embedding weight
    """
    assert isinstance(embed, np.ndarray), "embed must be a numpy.ndarray"
    weight = self.embed.embed.weight
    assert embed.shape == weight.shape, \
        "embed shape must match the embedding weight shape"
    # copy_ converts to the destination's dtype/device instead of replacing it.
    weight.data.copy_(torch.from_numpy(embed))
self.fc = nn.Linear(sum(kernel_nums), num_classes) | |||
def forward(self, words, seq_len=None): | |||
""" | |||
@@ -61,7 +51,7 @@ class CNNText(torch.nn.Module): | |||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | |||
x = self.dropout(x) | |||
x = self.fc(x) # [N,C] -> [N, N_class] | |||
return {'pred': x} | |||
return {C.OUTPUT: x} | |||
def predict(self, words, seq_len=None): | |||
""" | |||
@@ -71,5 +61,5 @@ class CNNText(torch.nn.Module): | |||
:return predict: dict of torch.LongTensor, [batch_size, ] | |||
""" | |||
output = self(words, seq_len) | |||
_, predict = output['pred'].max(dim=1) | |||
return {'pred': predict} | |||
_, predict = output[C.OUTPUT].max(dim=1) | |||
return {C.OUTPUT: predict} |
@@ -4,7 +4,8 @@ from .base_model import BaseModel | |||
from ..modules import decoder, encoder | |||
from ..modules.decoder.CRF import allowed_transitions | |||
from ..modules.utils import seq_mask | |||
from ..core.const import Const as C | |||
from torch import nn | |||
class SeqLabeling(BaseModel): | |||
""" | |||
@@ -24,7 +25,7 @@ class SeqLabeling(BaseModel): | |||
self.Embedding = encoder.embedding.Embedding(init_embed) | |||
self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size) | |||
self.Linear = encoder.linear.Linear(hidden_size, num_classes) | |||
self.Linear = nn.Linear(hidden_size, num_classes) | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||
self.mask = None | |||
@@ -46,7 +47,7 @@ class SeqLabeling(BaseModel): | |||
# [batch_size, max_len, hidden_size * direction] | |||
x = self.Linear(x) | |||
# [batch_size, max_len, num_classes] | |||
return {"loss": self._internal_loss(x, target)} | |||
return {C.LOSS: self._internal_loss(x, target)} | |||
def predict(self, words, seq_len): | |||
""" | |||
@@ -65,7 +66,7 @@ class SeqLabeling(BaseModel): | |||
x = self.Linear(x) | |||
# [batch_size, max_len, num_classes] | |||
pred = self._decode(x) | |||
return {'pred': pred} | |||
return {C.OUTPUT: pred} | |||
def _internal_loss(self, x, y): | |||
""" | |||
@@ -0,0 +1,22 @@ | |||
import unittest | |||
from test.models.model_runner import * | |||
from fastNLP.models.cnn_text_classification import CNNText | |||
class TestCNNText(unittest.TestCase):
    def test_case1(self):
        # Smoke test: check that the CNN text classifier runs normally.
        # The embedding is described by a (VOCAB_SIZE, 30) tuple.
        embed_spec = (VOCAB_SIZE, 30)
        cnn_options = dict(
            kernel_nums=(1, 3, 5),
            kernel_sizes=(2, 2, 2),
            padding=0,
            dropout=0.5,
        )
        classifier = CNNText(embed_spec, NUM_CLS, **cnn_options)
        RUNNER.run_model_with_task(TEXT_CLS, classifier)
if __name__ == '__main__':
    # Use the unittest runner rather than calling the test method by hand, so
    # fixtures run and failures are reported with a proper exit status.
    unittest.main()