|
@@ -7,6 +7,7 @@ import torch |
|
|
from tensorboardX import SummaryWriter |
|
|
from tensorboardX import SummaryWriter |
|
|
from torch import nn |
|
|
from torch import nn |
|
|
from tqdm.autonotebook import tqdm |
|
|
from tqdm.autonotebook import tqdm |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from fastNLP.core.batch import Batch |
|
|
from fastNLP.core.batch import Batch |
|
|
from fastNLP.core.dataset import DataSet |
|
|
from fastNLP.core.dataset import DataSet |
|
@@ -97,7 +98,8 @@ class Trainer(object): |
|
|
|
|
|
|
|
|
if check_code_level > -1: |
|
|
if check_code_level > -1: |
|
|
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, |
|
|
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, |
|
|
metric_key=metric_key, check_level=check_code_level) |
|
|
|
|
|
|
|
|
metric_key=metric_key, check_level=check_code_level, |
|
|
|
|
|
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) |
|
|
|
|
|
|
|
|
self.train_data = train_data |
|
|
self.train_data = train_data |
|
|
self.dev_data = dev_data # If None, No validation. |
|
|
self.dev_data = dev_data # If None, No validation. |
|
@@ -113,8 +115,6 @@ class Trainer(object): |
|
|
self.best_metric_indicator = None |
|
|
self.best_metric_indicator = None |
|
|
self.sampler = sampler |
|
|
self.sampler = sampler |
|
|
|
|
|
|
|
|
self._model_device = model.parameters().__next__().device |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
self.optimizer = optimizer |
|
|
self.optimizer = optimizer |
|
|
else: |
|
|
else: |
|
@@ -123,6 +123,7 @@ class Trainer(object): |
|
|
self.use_tqdm = use_tqdm |
|
|
self.use_tqdm = use_tqdm |
|
|
if self.use_tqdm: |
|
|
if self.use_tqdm: |
|
|
tester_verbose = 0 |
|
|
tester_verbose = 0 |
|
|
|
|
|
self.print_every = abs(self.print_every) |
|
|
else: |
|
|
else: |
|
|
tester_verbose = 1 |
|
|
tester_verbose = 1 |
|
|
|
|
|
|
|
@@ -137,17 +138,44 @@ class Trainer(object): |
|
|
self.step = 0 |
|
|
self.step = 0 |
|
|
self.start_time = None # start timestamp |
|
|
self.start_time = None # start timestamp |
|
|
|
|
|
|
|
|
def train(self): |
|
|
|
|
|
"""Start Training. |
|
|
|
|
|
|
|
|
def train(self, load_best_model=True): |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
开始训练过程。主要有以下几个步骤 |
|
|
|
|
|
for epoch in range(num_epochs): |
|
|
|
|
|
(1) 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为float, int的fields进行padding。并转换为Tensor。非 |
|
|
|
|
|
float,int类型的参数将不会被转换为Tensor,且不进行padding |
|
|
|
|
|
for batch_x, batch_y in Batch(DataSet): |
|
|
|
|
|
# batch_x中为设置为input的field |
|
|
|
|
|
# batch_y中为设置为target的field |
|
|
|
|
|
(2) 将batch_x的数据送入到model.forward函数中,并获取结果 |
|
|
|
|
|
(3) 将batch_y与model.forward的结果一并送入loss中计算loss |
|
|
|
|
|
(4) 获取到loss之后,进行反向求导并更新梯度 |
|
|
|
|
|
if dev_data is not None: |
|
|
|
|
|
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 |
|
|
|
|
|
|
|
|
|
|
|
:param load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现最好的 |
|
|
|
|
|
模型参数。 |
|
|
|
|
|
|
|
|
|
|
|
将会返回一个字典类型的数据, 内含以下内容: |
|
|
|
|
|
seconds: float, 表示训练时长 |
|
|
|
|
|
以下三个内容只有在提供了dev_data的情况下会有。 |
|
|
|
|
|
best_eval: Dict of Dict, 表示evaluation的结果 |
|
|
|
|
|
best_epoch: int,在第几个epoch取得的最佳值 |
|
|
|
|
|
best_step: int, 在第几个step(batch)更新取得的最佳值 |
|
|
|
|
|
|
|
|
|
|
|
return dict: |
|
|
""" |
|
|
""" |
|
|
|
|
|
results = {} |
|
|
try: |
|
|
try: |
|
|
if torch.cuda.is_available() and self.use_cuda: |
|
|
if torch.cuda.is_available() and self.use_cuda: |
|
|
self.model = self.model.cuda() |
|
|
self.model = self.model.cuda() |
|
|
|
|
|
self._model_device = self.model.parameters().__next__().device |
|
|
|
|
|
|
|
|
self._mode(self.model, is_test=False) |
|
|
self._mode(self.model, is_test=False) |
|
|
|
|
|
|
|
|
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S')) |
|
|
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S')) |
|
|
|
|
|
start_time = time.time() |
|
|
print("training epochs started " + self.start_time, flush=True) |
|
|
print("training epochs started " + self.start_time, flush=True) |
|
|
if self.save_path is None: |
|
|
if self.save_path is None: |
|
|
class psudoSW: |
|
|
class psudoSW: |
|
@@ -165,26 +193,37 @@ class Trainer(object): |
|
|
self._tqdm_train() |
|
|
self._tqdm_train() |
|
|
else: |
|
|
else: |
|
|
self._print_train() |
|
|
self._print_train() |
|
|
|
|
|
|
|
|
|
|
|
if self.dev_data is not None: |
|
|
|
|
|
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + |
|
|
|
|
|
self.tester._format_eval_results(self.best_dev_perf),) |
|
|
|
|
|
results['best_eval'] = self.best_dev_perf |
|
|
|
|
|
results['best_epoch'] = self.best_dev_epoch |
|
|
|
|
|
results['best_step'] = self.best_dev_step |
|
|
|
|
|
if load_best_model: |
|
|
|
|
|
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) |
|
|
|
|
|
self._load_model(self.model, model_name) |
|
|
|
|
|
print("Reloaded the best model.") |
|
|
finally: |
|
|
finally: |
|
|
self._summary_writer.close() |
|
|
self._summary_writer.close() |
|
|
del self._summary_writer |
|
|
del self._summary_writer |
|
|
|
|
|
results['seconds'] = round(time.time() - start_time, 2) |
|
|
|
|
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def _tqdm_train(self): |
|
|
def _tqdm_train(self): |
|
|
self.step = 0 |
|
|
self.step = 0 |
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, |
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, |
|
|
as_numpy=False) |
|
|
as_numpy=False) |
|
|
total_steps = data_iterator.num_batches*self.n_epochs |
|
|
total_steps = data_iterator.num_batches*self.n_epochs |
|
|
epoch = 1 |
|
|
|
|
|
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
ava_loss = 0 |
|
|
|
|
|
|
|
|
avg_loss = 0 |
|
|
for epoch in range(1, self.n_epochs+1): |
|
|
for epoch in range(1, self.n_epochs+1): |
|
|
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) |
|
|
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) |
|
|
for batch_x, batch_y in data_iterator: |
|
|
for batch_x, batch_y in data_iterator: |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) |
|
|
prediction = self._data_forward(self.model, batch_x) |
|
|
prediction = self._data_forward(self.model, batch_x) |
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
ava_loss += loss.item() |
|
|
|
|
|
|
|
|
avg_loss += loss.item() |
|
|
self._grad_backward(loss) |
|
|
self._grad_backward(loss) |
|
|
self._update() |
|
|
self._update() |
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) |
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) |
|
@@ -194,18 +233,18 @@ class Trainer(object): |
|
|
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) |
|
|
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) |
|
|
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) |
|
|
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) |
|
|
if (self.step+1) % self.print_every == 0: |
|
|
if (self.step+1) % self.print_every == 0: |
|
|
pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every)) |
|
|
|
|
|
ava_loss = 0 |
|
|
|
|
|
pbar.update(1) |
|
|
|
|
|
|
|
|
pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every)) |
|
|
|
|
|
avg_loss = 0 |
|
|
|
|
|
pbar.update(self.print_every) |
|
|
self.step += 1 |
|
|
self.step += 1 |
|
|
if self.validate_every > 0 and self.step % self.validate_every == 0 \ |
|
|
if self.validate_every > 0 and self.step % self.validate_every == 0 \ |
|
|
and self.dev_data is not None: |
|
|
and self.dev_data is not None: |
|
|
eval_res = self._do_validation() |
|
|
|
|
|
|
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
self.tester._format_eval_results(eval_res) |
|
|
self.tester._format_eval_results(eval_res) |
|
|
pbar.write(eval_str) |
|
|
pbar.write(eval_str) |
|
|
if self.validate_every < 0 and self.dev_data: |
|
|
if self.validate_every < 0 and self.dev_data: |
|
|
eval_res = self._do_validation() |
|
|
|
|
|
|
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
self.tester._format_eval_results(eval_res) |
|
|
self.tester._format_eval_results(eval_res) |
|
|
pbar.write(eval_str) |
|
|
pbar.write(eval_str) |
|
@@ -244,25 +283,29 @@ class Trainer(object): |
|
|
|
|
|
|
|
|
if (self.validate_every > 0 and self.step % self.validate_every == 0 and |
|
|
if (self.validate_every > 0 and self.step % self.validate_every == 0 and |
|
|
self.dev_data is not None): |
|
|
self.dev_data is not None): |
|
|
self._do_validation() |
|
|
|
|
|
|
|
|
self._do_validation(epoch=epoch, step=self.step) |
|
|
|
|
|
|
|
|
self.step += 1 |
|
|
self.step += 1 |
|
|
|
|
|
|
|
|
# validate_every override validation at end of epochs |
|
|
# validate_every override validation at end of epochs |
|
|
if self.dev_data and self.validate_every <= 0: |
|
|
if self.dev_data and self.validate_every <= 0: |
|
|
self._do_validation() |
|
|
|
|
|
|
|
|
self._do_validation(epoch=epoch, step=self.step) |
|
|
epoch += 1 |
|
|
epoch += 1 |
|
|
|
|
|
|
|
|
def _do_validation(self, epoch, step):
    """Evaluate on the dev data, log the metrics, and track the best result.

    Every metric value returned by ``self.tester.test()`` is written to the
    summary writer under the tag ``valid_<metric name>_<metric key>``. When
    ``self._better_eval_result`` accepts the result, it is recorded as the best
    one so far and, if ``self.save_path`` is set, the model is saved.

    :param epoch: epoch in which this validation runs; stored as ``best_dev_epoch``
        when the result is the best so far.
    :param step: training step at which this validation runs; used as the
        ``global_step`` of the logged scalars and stored as ``best_dev_step``.
    :return: the evaluation result returned by ``self.tester.test()``.
    """
    res = self.tester.test()
    for name, metric in res.items():
        for metric_key, metric_val in metric.items():
            self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                            global_step=step)
    if self._better_eval_result(res):
        if self.save_path is not None:
            # metric_key may be None; str.join() raises TypeError on a None
            # element, so fall back to an empty string.
            key_part = self.metric_key if self.metric_key is not None else ""
            self._save_model(self.model,
                             "best_" + "_".join([self.model.__class__.__name__, key_part, self.start_time]))
        self.best_dev_perf = res
        self.best_dev_epoch = epoch
        self.best_dev_step = step
    return res
|
|
|
|
|
|
|
|
def _mode(self, model, is_test=False): |
|
|
def _mode(self, model, is_test=False): |
|
@@ -317,6 +360,15 @@ class Trainer(object): |
|
|
else: |
|
|
else: |
|
|
torch.save(model, model_name) |
|
|
torch.save(model, model_name) |
|
|
|
|
|
|
|
|
|
|
|
def _load_model(self, model, model_name, only_param=False):
    """Restore ``model``'s weights in place from a checkpoint under ``self.save_path``.

    Nothing is loaded when ``self.save_path`` is None.

    :param model: module whose ``load_state_dict`` receives the stored states.
    :param model_name: checkpoint file name; it is joined onto ``self.save_path``.
    :param only_param: True when the file holds only a ``state_dict``; False when
        it holds a whole pickled model, whose ``state_dict()`` is then used.
    """
    if self.save_path is None:
        return
    model_path = os.path.join(self.save_path, model_name)
    # Must be torch.load: calling torch.save here would overwrite the checkpoint
    # with the current weights and return None instead of the stored states.
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints written by this trainer. Recent torch releases default to
    # weights_only=True, which may reject whole-model files -- confirm against
    # the torch version in use.
    loaded = torch.load(model_path)
    states = loaded if only_param else loaded.state_dict()
    model.load_state_dict(states)
|
|
|
|
|
|
|
|
def _better_eval_result(self, metrics): |
|
|
def _better_eval_result(self, metrics): |
|
|
"""Check if the current epoch yields better validation results. |
|
|
"""Check if the current epoch yields better validation results. |
|
|
|
|
|
|
|
@@ -344,6 +396,21 @@ class Trainer(object): |
|
|
DEFAULT_CHECK_BATCH_SIZE = 2 |
|
|
DEFAULT_CHECK_BATCH_SIZE = 2 |
|
|
DEFAULT_CHECK_NUM_BATCH = 2 |
|
|
DEFAULT_CHECK_NUM_BATCH = 2 |
|
|
|
|
|
|
|
|
|
|
|
def _get_value_info(_dict):
    """Describe every value of ``_dict`` as one human-readable line.

    Tensors and numpy arrays are reported with their dtype and shape; any other
    value is reported with its Python type only.

    :param _dict: mapping from field name to value.
    :return: list of str, one entry per item, in the mapping's iteration order.
    """

    def _describe(field_name, field_value):
        # Build the description line for a single field.
        if isinstance(field_value, torch.Tensor):
            return "\t{}: (1)type:torch.Tensor (2)dtype:{}, (3)shape:{} ".format(
                field_name, field_value.dtype, field_value.size())
        if isinstance(field_value, np.ndarray):
            return "\t{}: (1)type:numpy.ndarray (2)dtype:{}, (3)shape:{} ".format(
                field_name, field_value.dtype, field_value.shape)
        return "\t{}: type:{}".format(field_name, type(field_value))

    return [_describe(field_name, field_value) for field_name, field_value in _dict.items()]
|
|
|
|
|
|
|
|
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, |
|
|
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, |
|
|
dev_data=None, metric_key=None, |
|
|
dev_data=None, metric_key=None, |
|
@@ -356,8 +423,24 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) |
|
|
# forward check |
|
|
# forward check |
|
|
if batch_count==0: |
|
|
if batch_count==0: |
|
|
|
|
|
info_str = "" |
|
|
|
|
|
input_fields = _get_value_info(batch_x) |
|
|
|
|
|
target_fields = _get_value_info(batch_y) |
|
|
|
|
|
if len(input_fields)>0: |
|
|
|
|
|
info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) |
|
|
|
|
|
info_str += "\n".join(input_fields) |
|
|
|
|
|
info_str += '\n' |
|
|
|
|
|
else: |
|
|
|
|
|
raise RuntimeError("There is no input field.") |
|
|
|
|
|
if len(target_fields)>0: |
|
|
|
|
|
info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) |
|
|
|
|
|
info_str += "\n".join(target_fields) |
|
|
|
|
|
info_str += '\n' |
|
|
|
|
|
else: |
|
|
|
|
|
info_str += 'There is no target field.' |
|
|
|
|
|
print(info_str) |
|
|
_check_forward_error(forward_func=model.forward, dataset=dataset, |
|
|
_check_forward_error(forward_func=model.forward, dataset=dataset, |
|
|
batch_x=batch_x, check_level=check_level) |
|
|
|
|
|
|
|
|
batch_x=batch_x, check_level=check_level) |
|
|
|
|
|
|
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
pred_dict = model(**refined_batch_x) |
|
|
pred_dict = model(**refined_batch_x) |
|
|