@@ -1,7 +1,7 @@
 import os
 import time
 from datetime import datetime
-from datetime import timedelta
+from tqdm import tqdm

 import torch
 from tensorboardX import SummaryWriter
@@ -12,6 +12,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
+from fastNLP.core.sampler import BaseSampler
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
@@ -28,12 +29,10 @@ class Trainer(object):
     """

-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1,
-                 dev_data=None, use_cuda=False, save_path=None,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
+                 validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
-                 metric_key=None,
-                 **kwargs):
+                 metric_key=None, sampler=RandomSampler()):
         """

         :param DataSet train_data: the training data
@@ -42,7 +41,7 @@ class Trainer(object):
         :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
         :param int n_epochs: the number of training epochs
         :param int batch_size: batch size for training and validation
-        :param int print_every: step interval to print next training information. Default: -1(no print).
+        :param int update_every: step interval at which to print the next training information. Default: 50.
         :param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
         :param DataSet dev_data: the validation data
         :param use_cuda:
@@ -54,8 +53,7 @@ class Trainer(object):
                 smaller, add a `-` character in front of the string. For example
             ::
                 metric_key="-PPL" # language model gets better as perplexity gets smaller

-        :param kwargs:
-
+        :param BaseSampler sampler: the sampler used to generate batches. Default: RandomSampler().
         """
         super(Trainer, self).__init__()
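
A hedged usage sketch of the revised signature (the model, data, loss, and metric objects below are placeholders, not part of this patch): `print_every` becomes `update_every`, the `**kwargs` catch-all is gone, and batch order is now controlled by an explicit sampler.

```python
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.trainer import Trainer

trainer = Trainer(train_data=train_set,           # a fastNLP DataSet (placeholder)
                  model=model,                    # a torch.nn.Module (placeholder)
                  losser=loss, metrics=metric,    # loss/metric objects (placeholders)
                  n_epochs=3, batch_size=32,
                  update_every=50,                # renamed from print_every
                  dev_data=dev_set,               # validation DataSet (placeholder)
                  metric_key="-PPL",              # leading '-' means smaller is better
                  sampler=SequentialSampler())    # must be a fastNLP BaseSampler
trainer.train()
```
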
@@ -90,6 +88,10 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)

+        # sampler check
+        if not isinstance(sampler, BaseSampler):
+            raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
+
         if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level)
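
Since any `BaseSampler` subclass now passes this check, a custom batch order is easy to plug in. A minimal sketch, assuming the sampler protocol of this fastNLP version (a callable mapping a DataSet to a list of example indices, as SequentialSampler does); verify the protocol against your checkout:

```python
from fastNLP.core.sampler import BaseSampler

class ReverseSampler(BaseSampler):
    """Yield example indices in reverse dataset order."""
    def __call__(self, data_set):
        return list(range(len(data_set) - 1, -1, -1))
```

It would then be passed as `Trainer(..., sampler=ReverseSampler())`.
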
@@ -103,9 +105,10 @@ class Trainer(object):
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
         self.save_path = save_path
-        self.print_every = int(print_every)
+        self.print_every = int(update_every)
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
+        self.sampler = sampler

         self._model_device = model.parameters().__next__().device

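
A note on the `_model_device` lookup above: a module's device can be read off its first parameter. A minimal standalone sketch (it assumes the model has at least one parameter; a parameterless module would raise StopIteration):

```python
from torch import nn

model = nn.Linear(4, 2)
device = next(model.parameters()).device  # equivalent to model.parameters().__next__().device
print(device)  # cpu, or cuda:0 once model.cuda() has been called
```
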
@@ -119,10 +122,8 @@ class Trainer(object):
                                data=self.dev_data,
                                metrics=self.metrics,
                                batch_size=self.batch_size,
-                               use_cuda=self.use_cuda)
-
-        for k, v in kwargs.items():
-            setattr(self, k, v)
+                               use_cuda=self.use_cuda,
+                               verbose=0)

         self.step = 0
         self.start_time = None  # start timestamp
@@ -140,8 +141,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)

-        start = time.time()
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
         print("training epochs started " + self.start_time)
         if self.save_path is None:
             class psudoSW:
@@ -156,65 +156,81 @@ class Trainer(object):
                 path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
                 self._summary_writer = SummaryWriter(path)

-            epoch = 1
-            while epoch <= self.n_epochs:
-
-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
-                                      as_numpy=False)
-
-                self._train_epoch(data_iterator, self.model, epoch, start)
-
-                # validate_every overrides the end-of-epoch validation
-                if self.dev_data and self.validate_every <= 0:
-                    self._do_validation()
-                epoch += 1
+            self._tqdm_train()
         finally:
             self._summary_writer.close()
             del self._summary_writer

-    def _train_epoch(self, data_iterator, model, epoch, start):
-        """
-
-        :param data_iterator:
-        :param model:
-        :param epoch:
-        :param start:
-        :return:
-        """
-        for batch_x, batch_y in data_iterator:
-            # TODO a possible problem here: if the user moves `prediction` to another device inside the model, this breaks
-            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-            prediction = self._data_forward(model, batch_x)
-            loss = self._compute_loss(prediction, batch_y)
-            self._grad_backward(loss)
-            self._update()
-            self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
-            for name, param in self.model.named_parameters():
-                if param.requires_grad:
-                    self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
-                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
-                    # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-            if self.print_every > 0 and self.step % self.print_every == 0:
-                end = time.time()
-                diff = timedelta(seconds=round(end - start))
-                print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
-                    epoch, self.step, loss.data, diff)
-                print(print_output)
-
-            if self.validate_every > 0 and self.step % self.validate_every == 0:
-                self._do_validation()
-
-            self.step += 1
+    def _tqdm_train(self):
+        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
+                              as_numpy=False)
+        total_steps = data_iterator.num_batches * self.n_epochs
+        epoch = 1
+        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', desc="Epoch {}/{}"
+                  .format(epoch, self.n_epochs), leave=False, dynamic_ncols=True) as pbar:
+            avg_loss = 0
+            for epoch in range(1, self.n_epochs + 1):
+                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
+                for batch_x, batch_y in data_iterator:
+                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                    prediction = self._data_forward(self.model, batch_x)
+                    loss = self._compute_loss(prediction, batch_y)
+                    avg_loss += loss.item()
+                    self._grad_backward(loss)
+                    self._update()
+                    self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
+                    for name, param in self.model.named_parameters():
+                        if param.requires_grad:
+                            self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
+                            # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
+                            # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
+                    if (self.step + 1) % self.print_every == 0:
+                        pbar.update(self.print_every)
+                        pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
+                        avg_loss = 0
+
+                    self.step += 1
+                    if self.validate_every > 0 and self.step % self.validate_every == 0 \
+                            and self.dev_data is not None:
+                        eval_res = self._do_validation()
+                        eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
+                                   self.tester._format_eval_results(eval_res)
+                        pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
+                time.sleep(0.1)
+                if self.validate_every < 0 and self.dev_data:
+                    eval_res = self._do_validation()
+                    eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
+                               self.tester._format_eval_results(eval_res)
+                    pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
+                if epoch != self.n_epochs:
+                    data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
+                                          as_numpy=False)
+            pbar.close()
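
A minimal standalone sketch of the windowed-loss display used above: advance the bar every `print_every` steps and show the window-averaged loss as the postfix (the `losses` list is made up for illustration):

```python
from tqdm import tqdm

print_every = 10
losses = [1.0 / (i + 1) for i in range(100)]      # stand-in for real batch losses
with tqdm(total=len(losses), postfix='loss:{0:<6.5f}', dynamic_ncols=True) as pbar:
    avg_loss = 0
    for step, loss in enumerate(losses):
        avg_loss += loss
        if (step + 1) % print_every == 0:
            pbar.update(print_every)              # advance the bar in chunks
            pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / print_every))
            avg_loss = 0                          # reset the window
```
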
+    def _relocate_pbar(self, pbar, total, initial, print_str=None):
+        # tqdm cannot print above a live bar, so close it, print, then rebuild
+        # the bar at the same position while preserving its timing state.
+        postfix = pbar.postfix
+        desc = pbar.desc
+        pbar.close()
+        avg_time = pbar.avg_time
+        start_t = pbar.start_t
+        if print_str:
+            print(print_str)
+        pbar = tqdm(total=total, postfix=postfix, desc=desc, leave=False, initial=initial, dynamic_ncols=True)
+        pbar.start_t = start_t
+        pbar.avg_time = avg_time
+        pbar.sp(pbar.__repr__())
+        return pbar
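
The close-print-recreate dance in `_relocate_pbar` can be seen in isolation below; a minimal sketch assuming only tqdm (the `avg_time` attribute carried over above exists in tqdm releases of this era but was later removed, so the sketch restores only `start_t`):

```python
import time
from tqdm import tqdm

total = 100
pbar = tqdm(total=total, leave=False, dynamic_ncols=True)
for step in range(total):
    time.sleep(0.02)
    pbar.update(1)
    if step == total // 2:
        start_t = pbar.start_t            # preserve the timing state
        pbar.close()                      # tear the bar down...
        print("step {}: eval results would go here".format(step + 1))
        pbar = tqdm(total=total, initial=step + 1, leave=False, dynamic_ncols=True)
        pbar.start_t = start_t            # ...and restore the elapsed time
pbar.close()
```

Recent tqdm versions also offer `tqdm.write()`, which prints above a live bar without rebuilding it.
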
     def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
-            pass
-            # self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
+            self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
             metric_key = self.metric_key if self.metric_key is not None else "None"
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res

     def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.