@@ -28,7 +28,7 @@ class Trainer(object): | |||||
"""Main Training Loop | """Main Training Loop | ||||
""" | """ | ||||
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||||
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | ||||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | ||||
metric_key=None, sampler=RandomSampler(), use_tqdm=True): | metric_key=None, sampler=RandomSampler(), use_tqdm=True): | ||||
@@ -36,7 +36,7 @@ class Trainer(object): | |||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
:param torch.nn.modules.module model: a PyTorch model | :param torch.nn.modules.module model: a PyTorch model | ||||
:param LossBase losser: a loss object | |||||
:param LossBase loss: a loss object | |||||
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics | :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics | ||||
:param int n_epochs: the number of training epochs | :param int n_epochs: the number of training epochs | ||||
:param int batch_size: batch size for training and validation | :param int batch_size: batch size for training and validation | ||||
@@ -88,7 +88,7 @@ class Trainer(object): | |||||
self.metric_key = None | self.metric_key = None | ||||
# prepare loss | # prepare loss | ||||
losser = _prepare_losser(losser) | |||||
losser = _prepare_losser(loss) | |||||
# sampler check | # sampler check | ||||
if not isinstance(sampler, BaseSampler): | if not isinstance(sampler, BaseSampler): | ||||
@@ -7,7 +7,6 @@ from collections import namedtuple | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from tqdm import tqdm | |||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs'], verbose=False) | 'varargs'], verbose=False) | ||||
@@ -42,7 +42,6 @@ def prepare_fake_dataset2(*args, size=100): | |||||
class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
# 这里传入多余参数,让其duplicate | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | dataset = prepare_fake_dataset2('x1', 'x_unused') | ||||
dataset.rename_field('x_unused', 'x2') | dataset.rename_field('x_unused', 'x2') | ||||
dataset.set_input('x1', 'x2') | dataset.set_input('x1', 'x2') | ||||
@@ -60,8 +59,9 @@ class TestTester(unittest.TestCase): | |||||
return {'preds': x} | return {'preds': x} | ||||
model = Model() | model = Model() | ||||
tester = Tester( | |||||
data=dataset, | |||||
model=model, | |||||
metrics=AccuracyMetric()) | |||||
tester.test() | |||||
with self.assertRaises(NameError): | |||||
tester = Tester( | |||||
data=dataset, | |||||
model=model, | |||||
metrics=AccuracyMetric()) | |||||
tester.test() |
@@ -48,7 +48,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
model = NaiveClassifier(2, 1) | model = NaiveClassifier(2, 1) | ||||
trainer = Trainer(train_set, model, | trainer = Trainer(train_set, model, | ||||
losser=BCELoss(pred="predict", target="y"), | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
n_epochs=10, | n_epochs=10, | ||||
batch_size=32, | batch_size=32, | ||||
@@ -227,14 +227,15 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'preds': x} | return {'preds': x} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
losser=CrossEntropyLoss(), | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2) | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
loss=CrossEntropyLoss(), | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2) | |||||
def test_case2(self): | def test_case2(self): | ||||
# check metrics Wrong | # check metrics Wrong | ||||