@@ -27,11 +27,13 @@ Example::

     tester = Tester(dataset, model, metrics=AccuracyMetric())
     eval_results = tester.test()

-    The Metric field mapping here follows the same rules as in :class:`fastNLP.Trainer`; see part 1.3 of the :doc:`trainer module<fastNLP.core.trainer>` for details
+    The Metric field mapping here follows the same rules as in :class:`fastNLP.Trainer`; see part 1.3 of the :doc:`trainer module<fastNLP.core.trainer>` for details.
+    Before validation begins, Tester calls model.eval() to signal the evaluation phase, which turns off nn.Dropout() and the like; after validation ends it calls model.train() to restore training mode.

 """
+import warnings
+
 import torch
 from torch import nn
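The eval()/train() bracketing described by the new docstring line is the standard
PyTorch evaluation pattern. A minimal sketch of what Tester.test() does around its
evaluation loop (names simplified; not the actual implementation)::

    def test(self):
        self._model.eval()          # disable nn.Dropout, use eval-mode BatchNorm statistics
        try:
            with torch.no_grad():   # gradients are not needed for evaluation
                for batch_x, batch_y in self.data_iterator:
                    pred = self._predict_func(**batch_x)
                    ...             # accumulate metric values
        finally:
            self._model.train()     # always restore training mode afterwards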
@@ -72,6 +74,7 @@ class Tester(object):

         5. None. If None, the model is left untouched; if the model passed in is a torch.nn.DataParallel, this value must be None.

+        If the model predicts through predict(), multi-GPU validation (DataParallel) is not available; only the model on the first GPU will be used.
     :param int verbose: print nothing if 0; print the validation results if 1.
     """
@@ -90,6 +93,13 @@ class Tester(object):
         self.batch_size = batch_size
         self.verbose = verbose

+        # predict() cannot be called through DataParallel
+        if isinstance(self._model, nn.DataParallel):
+            if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
+                warnings.warn("Cannot use DataParallel to test your model, because your model offers a predict() function,"
+                              " while DataParallel has no predict() function.")
+                self._model = self._model.module
+
         # check predict
         if hasattr(self._model, 'predict'):
             self._predict_func = self._model.predict
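The unwrapping above is needed because nn.DataParallel only replicates calls to
forward(); custom methods such as predict() remain reachable only on the wrapped
.module. A small illustration with a toy model (for demonstration only)::

    import torch.nn as nn

    class ToyModel(nn.Module):
        def forward(self, x):
            return x

        def predict(self, x):
            return x.argmax(dim=-1)

    wrapped = nn.DataParallel(ToyModel())
    print(hasattr(wrapped, 'predict'))         # False: DataParallel does not proxy custom methods
    print(hasattr(wrapped.module, 'predict'))  # True: the original model still exposes it
    model = wrapped.module                     # hence Tester falls back to the unwrapped model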
@@ -462,7 +462,6 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
-        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

         self.n_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
@@ -492,7 +491,6 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp

-        print("callback_manager")
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
@@ -3,9 +3,9 @@

 import torch
 import torch.nn as nn
-import numpy as np

-import fastNLP.modules.encoder as encoder
+from ..core.const import Const as C
+from ..modules import encoder


 class CNNText(torch.nn.Module):
@@ -18,7 +18,7 @@ class CNNText(torch.nn.Module):

     :param int num_classes: total number of classes
     :param int,tuple(int) out_channels: number of output channels; if a list, its length must match that of kernel_sizes
     :param int,tuple(int) kernel_sizes: kernel size for each output channel.
-    :param int padding:
+    :param int padding: amount of padding added before and after each sentence, filled with zeros.
     :param float dropout: dropout probability
     """
@@ -38,17 +38,7 @@ class CNNText(torch.nn.Module):
             kernel_sizes=kernel_sizes,
             padding=padding)
         self.dropout = nn.Dropout(dropout)
-        self.fc = encoder.Linear(sum(kernel_nums), num_classes)
-
-    def init_embed(self, embed):
-        """
-        Load pre-trained embeddings.
-
-        :param numpy.ndarray embed: a vocab_size x embed_dim embedding matrix
-        :return:
-        """
-        assert isinstance(embed, np.ndarray)
-        assert embed.shape == self.embed.embed.weight.shape
-        self.embed.embed.weight.data = torch.from_numpy(embed)
+        self.fc = nn.Linear(sum(kernel_nums), num_classes)

     def forward(self, words, seq_len=None):
         """
@@ -61,7 +51,7 @@ class CNNText(torch.nn.Module):
         x = self.conv_pool(x)  # [N,L,C] -> [N,C]
         x = self.dropout(x)
         x = self.fc(x)  # [N,C] -> [N, N_class]
-        return {'pred': x}
+        return {C.OUTPUT: x}

     def predict(self, words, seq_len=None):
         """
@@ -71,5 +61,5 @@ class CNNText(torch.nn.Module):
         :return predict: dict of torch.LongTensor, [batch_size, ]
         """
         output = self(words, seq_len)
-        _, predict = output['pred'].max(dim=1)
-        return {'pred': predict}
+        _, predict = output[C.OUTPUT].max(dim=1)
+        return {C.OUTPUT: predict}
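Replacing string literals with Const keeps field names consistent between models,
losses, and metrics. A sketch of the idea (the values below mirror what
fastNLP.core.const exposes, reproduced here from memory)::

    class Const:
        INPUT = 'words'        # model input key
        INPUT_LEN = 'seq_len'  # sequence-length key
        OUTPUT = 'pred'        # model output key, so {C.OUTPUT: x} is {'pred': x}
        TARGET = 'target'      # gold-label key
        LOSS = 'loss'          # loss key

    # downstream code can then agree on keys without hard-coding strings:
    output = {Const.OUTPUT: [2, 0, 1]}   # stand-in for a model's return dict
    predictions = output[Const.OUTPUT]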
@@ -4,7 +4,8 @@ from .base_model import BaseModel
 from ..modules import decoder, encoder
 from ..modules.decoder.CRF import allowed_transitions
 from ..modules.utils import seq_mask
+from ..core.const import Const as C
+from torch import nn


 class SeqLabeling(BaseModel):
     """
@@ -24,7 +25,7 @@ class SeqLabeling(BaseModel):
         self.Embedding = encoder.embedding.Embedding(init_embed)
         self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size)
-        self.Linear = encoder.linear.Linear(hidden_size, num_classes)
+        self.Linear = nn.Linear(hidden_size, num_classes)
         self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
         self.mask = None
@@ -46,7 +47,7 @@ class SeqLabeling(BaseModel):
         # [batch_size, max_len, hidden_size * direction]
         x = self.Linear(x)
         # [batch_size, max_len, num_classes]
-        return {"loss": self._internal_loss(x, target)}
+        return {C.LOSS: self._internal_loss(x, target)}

     def predict(self, words, seq_len):
         """
@@ -65,7 +66,7 @@ class SeqLabeling(BaseModel):
         x = self.Linear(x)
         # [batch_size, max_len, num_classes]
         pred = self._decode(x)
-        return {'pred': pred}
+        return {C.OUTPUT: pred}

     def _internal_loss(self, x, y):
         """
@@ -0,0 +1,22 @@
+import unittest
+
+from test.models.model_runner import *
+from fastNLP.models.cnn_text_classification import CNNText
+
+
+class TestCNNText(unittest.TestCase):
+    def test_case1(self):
+        # check that the CNN model runs end to end
+        init_emb = (VOCAB_SIZE, 30)
+        model = CNNText(init_emb,
+                        NUM_CLS,
+                        kernel_nums=(1, 3, 5),
+                        kernel_sizes=(2, 2, 2),
+                        padding=0,
+                        dropout=0.5)
+        RUNNER.run_model_with_task(TEXT_CLS, model)
+
+
+if __name__ == '__main__':
+    TestCNNText().test_case1()
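The star import supplies VOCAB_SIZE, NUM_CLS, RUNNER, and TEXT_CLS from the shared
model_runner helper. Calling test_case1() directly works, but it bypasses unittest's
runner (setup, reporting, discovery); the conventional entry point would be (an
equivalent alternative, not what the diff uses)::

    if __name__ == '__main__':
        unittest.main()   # discovers and runs every test_* method in this module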