diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index fc555afb..db95a32d 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -167,6 +167,17 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=warmup_callback, check_code_level=2) trainer.train() + + def test_early_stop_callback(self): + """ + 需要观察是否真的 EarlyStop + """ + data_set, model = prepare_env() + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=2, n_epochs=10, print_every=5, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=EarlyStopCallback(1), check_code_level=2) + trainer.train() def test_control_C(): @@ -177,12 +188,10 @@ def test_control_C(): line1 = "\n\n\n\n\n*************************" line2 = "*************************\n\n\n\n\n" - class Wait(Callback): def on_epoch_end(self): time.sleep(5) - data_set, model = prepare_env() print(line1 + "Test starts!" + line2) @@ -204,4 +213,4 @@ def test_control_C(): if __name__ == "__main__": - test_control_C() \ No newline at end of file + test_control_C()