|
|
@@ -32,8 +32,8 @@ from fastNLP.core.utils import get_func_signature |
|
|
|
|
|
|
|
class Trainer(object): |
|
|
|
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, |
|
|
|
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), |
|
|
|
check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True, |
|
|
|
validate_every=-1, dev_data=None, save_path=None, optimizer=None, |
|
|
|
check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, |
|
|
|
use_cuda=False, callbacks=None): |
|
|
|
""" |
|
|
|
:param DataSet train_data: the training data |
|
|
@@ -96,7 +96,7 @@ class Trainer(object): |
|
|
|
losser = _prepare_losser(loss) |
|
|
|
|
|
|
|
# sampler check |
|
|
|
if not isinstance(sampler, BaseSampler): |
|
|
|
if sampler is not None and not isinstance(sampler, BaseSampler): |
|
|
|
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) |
|
|
|
|
|
|
|
if check_code_level > -1: |
|
|
@@ -119,13 +119,15 @@ class Trainer(object): |
|
|
|
self.best_dev_epoch = None |
|
|
|
self.best_dev_step = None |
|
|
|
self.best_dev_perf = None |
|
|
|
self.sampler = sampler |
|
|
|
self.sampler = sampler if sampler is not None else RandomSampler() |
|
|
|
self.prefetch = prefetch |
|
|
|
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) |
|
|
|
|
|
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
|
self.optimizer = optimizer |
|
|
|
else: |
|
|
|
if optimizer is None: |
|
|
|
optimizer = Adam(lr=0.01, weight_decay=0) |
|
|
|
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) |
|
|
|
|
|
|
|
self.use_tqdm = use_tqdm |
|
|
|