Browse Source

* change trainer iterating into tqdm

tags/v0.2.0^2
yh 6 years ago
parent
commit
beb55f5288
4 changed files with 96 additions and 69 deletions
  1. +16
    -5
      fastNLP/core/dataset.py
  2. +78
    -62
      fastNLP/core/trainer.py
  3. +1
    -1
      fastNLP/core/utils.py
  4. +1
    -1
      test/core/test_trainer.py

+ 16
- 5
fastNLP/core/dataset.py View File

@@ -216,25 +216,36 @@ class DataSet(object):

return wrapper

def apply(self, func, new_field_name=None, is_input=False, is_target=False):
def apply(self, func, new_field_name=None, **kwargs):
"""Apply a function to every instance of the DataSet.

:param func: a function that takes an instance as input.
:param str new_field_name: If not None, results of the function will be stored as a new field.
:param **kwargs: Accepted keyword arguments are
(1) is_input: boolean, ignored if new_field_name is None. If True, the new field will be marked as input.
(2) is_target: boolean, ignored if new_field_name is None. If True, the new field will be marked as target.
:return results: if new_field_name is not passed, returned values of the function over all instances.
"""
results = [func(ins) for ins in self]
extra_param = {}
if 'is_input' in kwargs:
extra_param['is_input'] = kwargs['is_input']
if 'is_target' in kwargs:
extra_param['is_target'] = kwargs['is_target']
if new_field_name is not None:
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name]
if 'is_input' not in extra_param:
extra_param['is_input'] = old_field.is_input
if 'is_target' not in extra_param:
extra_param['is_target'] = old_field.is_target
self.add_field(name=new_field_name,
fields=results,
padding_val=old_field.padding_val,
is_input=old_field.is_input,
is_target=old_field.is_target)
**extra_param)
else:
self.add_field(name=new_field_name, fields=results, is_input=is_input, is_target=is_target)
self.add_field(name=new_field_name, fields=results, **extra_param)
else:
return results

@@ -295,7 +306,7 @@ class DataSet(object):
for col in headers:
_dict[col] = []
for line_idx, line in enumerate(f, start_idx):
contents = line.split(sep)
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
if dropna:
continue


+ 78
- 62
fastNLP/core/trainer.py View File

@@ -1,7 +1,7 @@
import os
import time
from datetime import datetime
from datetime import timedelta
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter
@@ -12,6 +12,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import BaseSampler
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
@@ -28,12 +29,10 @@ class Trainer(object):

"""

def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1,
dev_data=None, use_cuda=False, save_path=None,
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None,
**kwargs):
metric_key=None, sampler=RandomSampler()):
"""

:param DataSet train_data: the training data
@@ -42,7 +41,7 @@ class Trainer(object):
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
:param int n_epochs: the number of training epochs
:param int batch_size: batch size for training and validation
:param int print_every: step interval to print next training information. Default: -1(no print).
:param int update_every: step interval at which the progress bar and its displayed loss are refreshed. Default: 50.
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
:param DataSet dev_data: the validation data
:param use_cuda:
@@ -54,8 +53,7 @@ class Trainer(object):
smaller, add a `-` character in front of the string. For example
::
metric_key="-PPL" # language model gets better as perplexity gets smaller

:param kwargs:
:param sampler: method used to generate batch data.

"""
super(Trainer, self).__init__()
@@ -90,6 +88,10 @@ class Trainer(object):
# prepare loss
losser = _prepare_losser(losser)

# sampler check
if not isinstance(sampler, BaseSampler):
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

if check_code_level > -1:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
metric_key=metric_key, check_level=check_code_level)
@@ -103,9 +105,10 @@ class Trainer(object):
self.batch_size = int(batch_size)
self.use_cuda = bool(use_cuda)
self.save_path = save_path
self.print_every = int(print_every)
self.print_every = int(update_every)
self.validate_every = int(validate_every)
self.best_metric_indicator = None
self.sampler = sampler

self._model_device = model.parameters().__next__().device

@@ -119,10 +122,8 @@ class Trainer(object):
data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size,
use_cuda=self.use_cuda)

for k, v in kwargs.items():
setattr(self, k, v)
use_cuda=self.use_cuda,
verbose=0)

self.step = 0
self.start_time = None # start timestamp
@@ -140,8 +141,7 @@ class Trainer(object):

self._mode(self.model, is_test=False)

start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("training epochs started " + self.start_time)
if self.save_path is None:
class psudoSW:
@@ -156,65 +156,81 @@ class Trainer(object):
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
self._summary_writer = SummaryWriter(path)

epoch = 1
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
as_numpy=False)

self._train_epoch(data_iterator, self.model, epoch, start)
self._tqdm_train()

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self._do_validation()
epoch += 1
finally:
self._summary_writer.close()
del self._summary_writer

def _train_epoch(self, data_iterator, model, epoch, start):
"""

:param data_iterator:
:param model:
:param epoch:
:param start:
:return:
"""
for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(model, batch_x)
loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if self.print_every > 0 and self.step % self.print_every == 0:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff)
print(print_output)

if self.validate_every > 0 and self.step % self.validate_every == 0:
self._do_validation()

self.step += 1
def _tqdm_train(self):
    """Run all ``self.n_epochs`` training epochs behind a single tqdm progress bar.

    Each step moves the batch to ``self._model_device``, runs the forward pass,
    computes the loss, back-propagates and updates the parameters, then writes
    the loss and the mean of every trainable parameter to the summary writer.

    The bar advances by ``self.print_every`` steps at a time and shows the loss
    averaged over those steps. A non-positive ``self.print_every`` disables the
    refresh instead of dividing by zero or moving the bar backwards.

    Validation runs every ``self.validate_every`` steps when that value is
    positive, or once at the end of each epoch when it is negative; in both
    cases only when ``self.dev_data`` is set. The formatted results are handed
    to ``self._relocate_pbar`` and the bar it returns replaces the current one.
    """
    data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                          as_numpy=False)
    total_steps = data_iterator.num_batches * self.n_epochs
    epoch = 1
    with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', desc="Epoch {}/{}"
              .format(epoch, self.n_epochs), leave=False, dynamic_ncols=True) as pbar:
        avg_loss = 0
        for epoch in range(1, self.n_epochs + 1):
            pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
            for batch_x, batch_y in data_iterator:
                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                prediction = self._data_forward(self.model, batch_x)
                loss = self._compute_loss(prediction, batch_y)
                avg_loss += loss.item()
                self._grad_backward(loss)
                self._update()
                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                for name, param in self.model.named_parameters():
                    if param.requires_grad:
                        self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)

                # Guard against print_every <= 0: zero would raise ZeroDivisionError and a
                # negative value would push the bar backwards and flip the sign of the loss.
                if self.print_every > 0 and (self.step + 1) % self.print_every == 0:
                    pbar.update(self.print_every)
                    pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
                    avg_loss = 0

                self.step += 1
                if self.validate_every > 0 and self.step % self.validate_every == 0 \
                        and self.dev_data is not None:
                    eval_res = self._do_validation()
                    eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
                    # NOTE(review): presumably a short pause so the printed result and the
                    # new bar do not interleave on the terminal -- confirm.
                    time.sleep(0.1)
            # A negative validate_every means "validate once per epoch".
            if self.validate_every < 0 and self.dev_data:
                eval_res = self._do_validation()
                eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                           self.tester._format_eval_results(eval_res)
                pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
            # Build a fresh batch iterator for the next epoch; none is needed after the last one.
            if epoch != self.n_epochs:
                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                      as_numpy=False)
        # pbar may have been rebound by _relocate_pbar, so the ``with`` exit only closed
        # the original bar; close the current one explicitly.
        pbar.close()

def _relocate_pbar(self, pbar, total, initial, print_str=None):
    """Close ``pbar``, optionally print a message, and return a replacement bar.

    The replacement keeps the old bar's postfix, description, start time and
    average time, and starts counting from ``initial``.

    :param pbar: the tqdm bar to close and replace.
    :param total: total step count for the new bar.
    :param initial: step count the new bar starts from.
    :param print_str: if truthy, printed between closing the old bar and creating the new one.
    :return: the newly created tqdm bar.
    """
    saved_postfix, saved_desc = pbar.postfix, pbar.desc
    pbar.close()
    # Timing state is read after close(), as in the original statement order.
    saved_avg_time, saved_start = pbar.avg_time, pbar.start_t
    if print_str:
        print(print_str)

    fresh_bar = tqdm(total=total, postfix=saved_postfix, desc=saved_desc, leave=False,
                     initial=initial, dynamic_ncols=True)
    fresh_bar.start_t = saved_start
    fresh_bar.avg_time = saved_avg_time
    # NOTE(review): presumably forces an immediate redraw of the new bar; ``sp`` is not
    # part of tqdm's documented public API -- confirm against the pinned tqdm version.
    fresh_bar.sp(repr(fresh_bar))
    return fresh_bar

def _do_validation(self):
res = self.tester.test()
for name, num in res.items():
pass
# self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
if self.save_path is not None and self._better_eval_result(res):
metric_key = self.metric_key if self.metric_key is not None else "None"
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
return res

def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.


+ 1
- 1
fastNLP/core/utils.py View File

@@ -248,7 +248,7 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
if _unused_field:
unuseds.append([f"\tunused field: {_unused_field}"])
if _unused_param:
unuseds.append([f"\tunused param: {_unused_param}"])
unuseds.append([f"\tunused param: {_unused_param}"]) # output from predict or forward

if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")


+ 1
- 1
test/core/test_trainer.py View File

@@ -36,7 +36,7 @@ class TrainerTestGround(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"),
n_epochs=10,
batch_size=32,
print_every=10,
update_every=1,
validate_every=-1,
dev_data=dev_set,
optimizer=SGD(0.1),


Loading…
Cancel
Save