|
@@ -336,7 +336,7 @@ except: |
|
|
import warnings |
|
|
import warnings |
|
|
|
|
|
|
|
|
from .batch import DataSetIter, BatchIter |
|
|
from .batch import DataSetIter, BatchIter |
|
|
from .callback import CallbackManager, CallbackException |
|
|
|
|
|
|
|
|
from .callback import CallbackManager, CallbackException, Callback |
|
|
from .dataset import DataSet |
|
|
from .dataset import DataSet |
|
|
from .losses import _prepare_losser |
|
|
from .losses import _prepare_losser |
|
|
from .metrics import _prepare_metrics |
|
|
from .metrics import _prepare_metrics |
|
@@ -422,13 +422,8 @@ class Trainer(object): |
|
|
batch_size=32, sampler=None, drop_last=False, update_every=1, |
|
|
batch_size=32, sampler=None, drop_last=False, update_every=1, |
|
|
num_workers=0, n_epochs=10, print_every=5, |
|
|
num_workers=0, n_epochs=10, print_every=5, |
|
|
dev_data=None, metrics=None, metric_key=None, |
|
|
dev_data=None, metrics=None, metric_key=None, |
|
|
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, |
|
|
|
|
|
|
|
|
validate_every=-1, save_path=None, use_tqdm=True, device=None, |
|
|
callbacks=None, check_code_level=0, **kwargs): |
|
|
callbacks=None, check_code_level=0, **kwargs): |
|
|
if prefetch and num_workers==0: |
|
|
|
|
|
num_workers = 1 |
|
|
|
|
|
if prefetch: |
|
|
|
|
|
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") |
|
|
|
|
|
|
|
|
|
|
|
super(Trainer, self).__init__() |
|
|
super(Trainer, self).__init__() |
|
|
if not isinstance(model, nn.Module): |
|
|
if not isinstance(model, nn.Module): |
|
|
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") |
|
|
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") |
|
@@ -566,6 +561,9 @@ class Trainer(object): |
|
|
self.step = 0 |
|
|
self.step = 0 |
|
|
self.start_time = None # start timestamp |
|
|
self.start_time = None # start timestamp |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(callbacks, Callback): |
|
|
|
|
|
callbacks = [callbacks] |
|
|
|
|
|
|
|
|
self.callback_manager = CallbackManager(env={"trainer": self}, |
|
|
self.callback_manager = CallbackManager(env={"trainer": self}, |
|
|
callbacks=callbacks) |
|
|
callbacks=callbacks) |
|
|
|
|
|
|
|
@@ -617,8 +615,8 @@ class Trainer(object): |
|
|
|
|
|
|
|
|
if self.dev_data is not None and self.best_dev_perf is not None: |
|
|
if self.dev_data is not None and self.best_dev_perf is not None: |
|
|
self.logger.info( |
|
|
self.logger.info( |
|
|
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + |
|
|
|
|
|
self.tester._format_eval_results(self.best_dev_perf), ) |
|
|
|
|
|
|
|
|
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step)) |
|
|
|
|
|
self.logger.info(self.tester._format_eval_results(self.best_dev_perf)) |
|
|
results['best_eval'] = self.best_dev_perf |
|
|
results['best_eval'] = self.best_dev_perf |
|
|
results['best_epoch'] = self.best_dev_epoch |
|
|
results['best_epoch'] = self.best_dev_epoch |
|
|
results['best_step'] = self.best_dev_step |
|
|
results['best_step'] = self.best_dev_step |
|
|