
Change two default init arguments of Trainer to None

tags/v0.4.10
FengZiYjun 6 years ago
commit 99d6bb208b
1 changed file with 6 additions and 4 deletions

fastNLP/core/trainer.py

@@ -32,8 +32,8 @@ from fastNLP.core.utils import get_func_signature
 
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True,
+                 validate_every=-1, dev_data=None, save_path=None, optimizer=None,
+                 check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True,
                  use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
@@ -96,7 +96,7 @@ class Trainer(object):
         losser = _prepare_losser(loss)
 
         # sampler check
-        if not isinstance(sampler, BaseSampler):
+        if sampler is not None and not isinstance(sampler, BaseSampler):
             raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
 
         if check_code_level > -1:
@@ -119,13 +119,15 @@ class Trainer(object):
         self.best_dev_epoch = None
         self.best_dev_step = None
         self.best_dev_perf = None
-        self.sampler = sampler
+        self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
+            if optimizer is None:
+                optimizer = Adam(lr=0.01, weight_decay=0)
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
 
         self.use_tqdm = use_tqdm
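Why this change matters: Python evaluates default argument values once, when the def statement runs, so optimizer=Adam(lr=0.01, weight_decay=0) and sampler=RandomSampler() each created a single object that every Trainer instance would share, and that was constructed even when the caller passed their own. Defaulting to None and constructing inside __init__ gives each Trainer a fresh object. A minimal sketch of the pitfall (the Sampler class and function names below are illustrative, not fastNLP code):

class Sampler:
    pass

def shared_default(sampler=Sampler()):
    # Sampler() above ran exactly once, at function definition time
    return sampler

def fresh_default(sampler=None):
    # the pattern this commit adopts: construct a new object per call
    return sampler if sampler is not None else Sampler()

print(shared_default() is shared_default())  # True  -- one shared instance
print(fresh_default() is fresh_default())    # False -- a new instance each call

The optimizer branch still accepts a ready-made torch.optim.Optimizer unchanged; None now falls through to the else branch, where the fastNLP Adam wrapper is built and bound to the model's parameters via construct_from_pytorch.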

