|
@@ -66,8 +66,7 @@ class TestCallback(unittest.TestCase): |
|
|
dev_data=data_set, |
|
|
dev_data=data_set, |
|
|
metrics=AccuracyMetric(pred="predict", target="y"), |
|
|
metrics=AccuracyMetric(pred="predict", target="y"), |
|
|
callbacks=[EarlyStopCallback(5)]) |
|
|
callbacks=[EarlyStopCallback(5)]) |
|
|
with self.assertRaises(EarlyStopError): |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
def test_lr_scheduler(self): |
|
|
def test_lr_scheduler(self): |
|
|
data_set, model = prepare_env() |
|
|
data_set, model = prepare_env() |
|
|