Browse Source

1. optimzer中调整默认参数大小

2. 移动Tester中device获取位置
3. trainer中
    (1) train()返回一个dict,并重新加载最佳模型
    (2) check_code时返回batch_x中每个数据的形状等信息
tags/v0.3.0
yh 5 years ago
parent
commit
30a0ff4d90
4 changed files with 111 additions and 27 deletions
  1. +4
    -3
      fastNLP/core/optimizer.py
  2. +1
    -1
      fastNLP/core/tester.py
  3. +105
    -22
      fastNLP/core/trainer.py
  4. +1
    -1
      fastNLP/core/utils.py

+ 4
- 3
fastNLP/core/optimizer.py View File

@@ -10,7 +10,7 @@ class Optimizer(object):


class SGD(Optimizer):
def __init__(self, lr=0.01, momentum=0, model_params=None):
def __init__(self, lr=0.001, momentum=0, model_params=None):
"""

:param float lr: learning rate. Default: 0.01
@@ -30,7 +30,7 @@ class SGD(Optimizer):


class Adam(Optimizer):
def __init__(self, lr=0.01, weight_decay=0, model_params=None):
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None):
"""

:param float lr: learning rate
@@ -39,7 +39,8 @@ class Adam(Optimizer):
"""
if not isinstance(lr, float):
raise TypeError("learning rate has to be float.")
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
weight_decay=weight_decay)

def construct_from_pytorch(self, model_params):
if self.model_params is None:


+ 1
- 1
fastNLP/core/tester.py View File

@@ -31,12 +31,12 @@ class Tester(object):
self.use_cuda = use_cuda
self.batch_size = batch_size
self.verbose = verbose
self._model_device = model.parameters().__next__().device

if torch.cuda.is_available() and self.use_cuda:
self._model = model.cuda()
else:
self._model = model
self._model_device = model.parameters().__next__().device

# check predict
if hasattr(self._model, 'predict'):


+ 105
- 22
fastNLP/core/trainer.py View File

@@ -7,6 +7,7 @@ import torch
from tensorboardX import SummaryWriter
from torch import nn
from tqdm.autonotebook import tqdm
import numpy as np

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
@@ -97,7 +98,8 @@ class Trainer(object):

if check_code_level > -1:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
metric_key=metric_key, check_level=check_code_level)
metric_key=metric_key, check_level=check_code_level,
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))

self.train_data = train_data
self.dev_data = dev_data # If None, No validation.
@@ -113,8 +115,6 @@ class Trainer(object):
self.best_metric_indicator = None
self.sampler = sampler

self._model_device = model.parameters().__next__().device

if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
else:
@@ -123,6 +123,7 @@ class Trainer(object):
self.use_tqdm = use_tqdm
if self.use_tqdm:
tester_verbose = 0
self.print_every = abs(self.print_every)
else:
tester_verbose = 1

@@ -137,17 +138,44 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp

def train(self):
"""Start Training.
def train(self, load_best_model=True):
"""

开始训练过程。主要有以下几个步骤
for epoch in range(num_epochs):
(1) 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为float, int的fields进行padding。并转换为Tensor。非
float,int类型的参数将不会被转换为Tensor,且不进行padding
for batch_x, batch_y in Batch(DataSet):
# batch_x中为设置为input的field
# batch_y中为设置为target的field
(2) 将batch_x的数据送入到model.forward函数中,并获取结果
(3) 将batch_y与model.forward的结果一并送入loss中计算loss
(4) 获取到loss之后,进行反向求导并更新梯度
if dev_data is not None:
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型

:param load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现最好的
模型参数。

将会返回一个字典类型的数据, 内含以下内容:
seconds: float, 表示训练时长
以下三个内容只有在提供了dev_data的情况下会有。
best_eval: Dict of Dict, 表示evaluation的结果
best_epoch: int,在第几个epoch取得的最佳值
best_step: int, 在第几个step(batch)更新取得的最佳值

return dict:
"""
results = {}
try:
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()
self._model_device = self.model.parameters().__next__().device

self._mode(self.model, is_test=False)

self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
start_time = time.time()
print("training epochs started " + self.start_time, flush=True)
if self.save_path is None:
class psudoSW:
@@ -165,26 +193,37 @@ class Trainer(object):
self._tqdm_train()
else:
self._print_train()

if self.dev_data is not None:
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf),)
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step
if load_best_model:
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
self._load_model(self.model, model_name)
print("Reloaded the best model.")
finally:
self._summary_writer.close()
del self._summary_writer
results['seconds'] = round(time.time() - start_time, 2)

return results

def _tqdm_train(self):
self.step = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)
total_steps = data_iterator.num_batches*self.n_epochs
epoch = 1
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
ava_loss = 0
avg_loss = 0
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self.model, batch_x)
loss = self._compute_loss(prediction, batch_y)
ava_loss += loss.item()
avg_loss += loss.item()
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -194,18 +233,18 @@ class Trainer(object):
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if (self.step+1) % self.print_every == 0:
pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
ava_loss = 0
pbar.update(1)
pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
avg_loss = 0
pbar.update(self.print_every)
self.step += 1
if self.validate_every > 0 and self.step % self.validate_every == 0 \
and self.dev_data is not None:
eval_res = self._do_validation()
eval_res = self._do_validation(epoch=epoch, step=self.step)
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str)
if self.validate_every < 0 and self.dev_data:
eval_res = self._do_validation()
eval_res = self._do_validation(epoch=epoch, step=self.step)
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str)
@@ -244,25 +283,29 @@ class Trainer(object):

if (self.validate_every > 0 and self.step % self.validate_every == 0 and
self.dev_data is not None):
self._do_validation()
self._do_validation(epoch=epoch, step=self.step)

self.step += 1

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self._do_validation()
self._do_validation(epoch=epoch, step=self.step)
epoch += 1

def _do_validation(self):
def _do_validation(self, epoch, step):
res = self.tester.test()
for name, metric in res.items():
for metric_key, metric_val in metric.items():
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
global_step=self.step)
if self.save_path is not None and self._better_eval_result(res):
metric_key = self.metric_key if self.metric_key is not None else ""
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
if self._better_eval_result(res):
if self.save_path is not None:
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))

self.best_dev_perf = res
self.best_dev_epoch = epoch
self.best_dev_step = step
return res

def _mode(self, model, is_test=False):
@@ -317,6 +360,15 @@ class Trainer(object):
else:
torch.save(model, model_name)

def _load_model(self, model, model_name, only_param=False):
if self.save_path is not None:
model_name = os.path.join(self.save_path, model_name)
if only_param:
states = torch.save(model.state_dict(), model_name)
else:
states = torch.save(model, model_name).state_dict()
model.load_state_dict(states)

def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.

@@ -344,6 +396,21 @@ class Trainer(object):
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

def _get_value_info(_dict):
# given a dict value, return information about this dict's value. Return list of str
strs = []
for key, value in _dict.items():
_str = ''
if isinstance(value, torch.Tensor):
_str += "\t{}: (1)type:torch.Tensor (2)dtype:{}, (3)shape:{} ".format(key,
value.dtype, value.size())
elif isinstance(value, np.ndarray):
_str += "\t{}: (1)type:numpy.ndarray (2)dtype:{}, (3)shape:{} ".format(key,
value.dtype, value.shape)
else:
_str += "\t{}: type:{}".format(key, type(value))
strs.append(_str)
return strs

def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None,
@@ -356,8 +423,24 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check
if batch_count==0:
info_str = ""
input_fields = _get_value_info(batch_x)
target_fields = _get_value_info(batch_y)
if len(input_fields)>0:
info_str += "input fields after batch(if batch size is {}):\n".format(batch_size)
info_str += "\n".join(input_fields)
info_str += '\n'
else:
raise RuntimeError("There is no input field.")
if len(target_fields)>0:
info_str += "target fields after batch(if batch size is {}):\n".format(batch_size)
info_str += "\n".join(target_fields)
info_str += '\n'
else:
info_str += 'There is no target field.'
print(info_str)
_check_forward_error(forward_func=model.forward, dataset=dataset,
batch_x=batch_x, check_level=check_level)
batch_x=batch_x, check_level=check_level)

refined_batch_x = _build_args(model.forward, **batch_x)
pred_dict = model(**refined_batch_x)


+ 1
- 1
fastNLP/core/utils.py View File

@@ -125,7 +125,7 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys())
missing = list(require_args - input_args)
unused = list(input_args - all_args)
varargs = [] if not spect.varargs else [arg for arg in spect.varargs]
varargs = [] if not spect.varargs else [spect.varargs]
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,


Loading…
Cancel
Save