|
|
@@ -54,6 +54,24 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
""" |
|
|
|
# 应该正确运行 |
|
|
|
""" |
|
|
|
|
|
|
|
def test_save_path(self): |
|
|
|
data_set = prepare_fake_dataset() |
|
|
|
data_set.set_input("x", flag=True) |
|
|
|
data_set.set_target("y", flag=True) |
|
|
|
|
|
|
|
train_set, dev_set = data_set.split(0.3) |
|
|
|
|
|
|
|
model = NaiveClassifier(2, 1) |
|
|
|
|
|
|
|
save_path = 'test_save_models' |
|
|
|
|
|
|
|
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), |
|
|
|
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set, |
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path, |
|
|
|
use_tqdm=True, check_code_level=2) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
|
|
|
def test_trainer_suggestion1(self): |
|
|
|
# 检查报错提示能否正确提醒用户。 |
|
|
|