From 99d6bb208bf9294d1ac88054e426f0df200b7f9f Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 16 Mar 2019 19:54:03 +0800
Subject: [PATCH] change two default init arguments of Trainer to None

---
 fastNLP/core/trainer.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 5381fc5d..8880291d 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -32,8 +32,8 @@ from fastNLP.core.utils import get_func_signature
 
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True,
+                 validate_every=-1, dev_data=None, save_path=None, optimizer=None,
+                 check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True,
                  use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
@@ -96,7 +96,7 @@ class Trainer(object):
         losser = _prepare_losser(loss)
 
         # sampler check
-        if not isinstance(sampler, BaseSampler):
+        if sampler is not None and not isinstance(sampler, BaseSampler):
             raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
 
         if check_code_level > -1:
@@ -119,13 +119,15 @@ class Trainer(object):
         self.best_dev_epoch = None
         self.best_dev_step = None
         self.best_dev_perf = None
-        self.sampler = sampler
+        self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
+            if optimizer is None:
+                optimizer = Adam(lr=0.01, weight_decay=0)
            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
 
         self.use_tqdm = use_tqdm
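
The likely motivation for this patch: Python evaluates default argument expressions once, at function definition time, so defaults like optimizer=Adam(lr=0.01, weight_decay=0) and sampler=RandomSampler() create single objects that are shared by every Trainer constructed in the process. Moving construction behind a None default gives each Trainer its own optimizer and sampler. A minimal standalone sketch of the pitfall (the Optimizer class and train_* functions below are hypothetical illustrations, not fastNLP code):

    class Optimizer:
        """Toy stand-in for a stateful optimizer object."""
        def __init__(self):
            self.steps = 0

    def train_shared(optimizer=Optimizer()):
        # Default evaluated once at definition time: all calls share one object.
        optimizer.steps += 1
        return optimizer.steps

    def train_fresh(optimizer=None):
        # The pattern this patch adopts: construct a fresh default per call.
        if optimizer is None:
            optimizer = Optimizer()
        optimizer.steps += 1
        return optimizer.steps

    print(train_shared(), train_shared())  # 1 2 -> state leaks across calls
    print(train_fresh(), train_fresh())    # 1 1 -> each call gets its own object

The same reasoning explains the extra guard in the sampler check: with a None default, the isinstance test should only run when the caller actually passed a sampler.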