Browse Source

Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh 6 years ago
parent
commit
a45f18cba5
4 changed files with 58 additions and 20 deletions
  1. +43
    -16
      fastNLP/core/tester.py
  2. +7
    -3
      fastNLP/core/trainer.py
  3. +7
    -0
      reproduction/text_classification/README.md
  4. +1
    -1
      reproduction/text_classification/train_char_cnn.py

+ 43
- 16
fastNLP/core/tester.py View File

@@ -32,9 +32,16 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation




""" """
import time

import torch import torch
import torch.nn as nn import torch.nn as nn


try:
from tqdm.auto import tqdm
except:
from .utils import _pseudo_tqdm as tqdm

from .batch import BatchIter, DataSetIter from .batch import BatchIter, DataSetIter
from .dataset import DataSet from .dataset import DataSet
from .metrics import _prepare_metrics from .metrics import _prepare_metrics
@@ -47,7 +54,7 @@ from .utils import _get_func_signature
from .utils import _get_model_device from .utils import _get_model_device
from .utils import _move_model_to_device from .utils import _move_model_to_device
from ._parallel_utils import _data_parallel_wrapper from ._parallel_utils import _data_parallel_wrapper
from fastNLP.core._parallel_utils import _model_contains_inner_module
from ._parallel_utils import _model_contains_inner_module
from functools import partial from functools import partial


__all__ = [ __all__ = [
@@ -80,9 +87,10 @@ class Tester(object):


如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
:param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。
""" """
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
super(Tester, self).__init__() super(Tester, self).__init__()


if not isinstance(model, nn.Module): if not isinstance(model, nn.Module):
@@ -94,6 +102,7 @@ class Tester(object):
self._model = _move_model_to_device(model, device=device) self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size self.batch_size = batch_size
self.verbose = verbose self.verbose = verbose
self.use_tqdm = use_tqdm


if isinstance(data, DataSet): if isinstance(data, DataSet):
self.data_iterator = DataSetIter( self.data_iterator = DataSetIter(
@@ -141,21 +150,39 @@ class Tester(object):
eval_results = {} eval_results = {}
try: try:
with torch.no_grad(): with torch.no_grad():
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
pred_dict = self._data_forward(self._predict_func, batch_x)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(pred_dict)}.")
if not self.use_tqdm:
from .utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
pbar.set_description_str(desc="Test")

start_time = time.time()

for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
pred_dict = self._data_forward(self._predict_func, batch_x)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(pred_dict)}.")
for metric in self.metrics:
metric(pred_dict, batch_y)

if self.use_tqdm:
pbar.update()

for metric in self.metrics: for metric in self.metrics:
metric(pred_dict, batch_y)
for metric in self.metrics:
eval_result = metric.get_metric()
if not isinstance(eval_result, dict):
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
f"`dict`, got {type(eval_result)}")
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
eval_result = metric.get_metric()
if not isinstance(eval_result, dict):
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
f"`dict`, got {type(eval_result)}")
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result

end_time = time.time()
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
pbar.write(test_str)
pbar.close()
except _CheckError as e: except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func) prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,


+ 7
- 3
fastNLP/core/trainer.py View File

@@ -352,7 +352,7 @@ from .utils import _move_dict_value_to_device
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import _get_model_device from .utils import _get_model_device
from .utils import _move_model_to_device from .utils import _move_model_to_device
from fastNLP.core._parallel_utils import _model_contains_inner_module
from ._parallel_utils import _model_contains_inner_module




class Trainer(object): class Trainer(object):
@@ -557,7 +557,8 @@ class Trainer(object):
metrics=self.metrics, metrics=self.metrics,
batch_size=self.batch_size, batch_size=self.batch_size,
device=None, # 由上面的部分处理device device=None, # 由上面的部分处理device
verbose=0)
verbose=0,
use_tqdm=self.use_tqdm)


self.step = 0 self.step = 0
self.start_time = None # start timestamp self.start_time = None # start timestamp
@@ -633,7 +634,7 @@ class Trainer(object):


def _train(self): def _train(self):
if not self.use_tqdm: if not self.use_tqdm:
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
from .utils import _pseudo_tqdm as inner_tqdm
else: else:
inner_tqdm = tqdm inner_tqdm = tqdm
self.step = 0 self.step = 0
@@ -859,8 +860,11 @@ def _get_value_info(_dict):
strs.append(_str) strs.append(_str)
return strs return strs



from numbers import Number from numbers import Number
from .batch import _to_tensor from .batch import _to_tensor


def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None, check_level=0): dev_data=None, metric_key=None, check_level=0):
# check get_loss 方法 # check get_loss 方法


+ 7
- 0
reproduction/text_classification/README.md View File

@@ -11,6 +11,13 @@ LSTM+self_attention:论文链接[A Structured Self-attentive Sentence Embedding]


AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models](https://arxiv.org/pdf/1708.02182.pdf) AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models](https://arxiv.org/pdf/1708.02182.pdf)


# 数据集来源
IMDB:http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
SST-2:https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
SST:https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
yelp_full:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
yelp_polarity:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M

# 数据集及复现结果汇总 # 数据集及复现结果汇总


使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) 使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果)


+ 1
- 1
reproduction/text_classification/train_char_cnn.py View File

@@ -203,7 +203,7 @@ callbacks.append(
def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size, trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size,
metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1, metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1,
n_epochs=num_epochs)
n_epochs=num_epochs,callbacks=callbacks)
print(trainer.train()) print(trainer.train())






Loading…
Cancel
Save