|
|
@@ -28,7 +28,7 @@ class Trainer(object): |
|
|
|
"""Main Training Loop |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, |
|
|
|
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, |
|
|
|
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, |
|
|
|
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, |
|
|
|
metric_key=None, sampler=RandomSampler(), use_tqdm=True): |
|
|
@@ -36,7 +36,7 @@ class Trainer(object): |
|
|
|
|
|
|
|
:param DataSet train_data: the training data |
|
|
|
:param torch.nn.modules.module model: a PyTorch model |
|
|
|
:param LossBase losser: a loss object |
|
|
|
:param LossBase loss: a loss object |
|
|
|
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics |
|
|
|
:param int n_epochs: the number of training epochs |
|
|
|
:param int batch_size: batch size for training and validation |
|
|
@@ -88,7 +88,7 @@ class Trainer(object): |
|
|
|
self.metric_key = None |
|
|
|
|
|
|
|
# prepare loss |
|
|
|
losser = _prepare_losser(losser) |
|
|
|
losser = _prepare_losser(loss) |
|
|
|
|
|
|
|
# sampler check |
|
|
|
if not isinstance(sampler, BaseSampler): |
|
|
|