Browse Source

add the test for EarlyStopCallback, BUG found!

tags/v0.4.10
ChenXin 5 years ago
parent
commit
4f0ec4a081
1 changed files with 12 additions and 3 deletions
  1. +12
    -3
      test/core/test_callbacks.py

+ 12
- 3
test/core/test_callbacks.py View File

@@ -167,6 +167,17 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=warmup_callback, check_code_level=2) callbacks=warmup_callback, check_code_level=2)
trainer.train() trainer.train()
def test_early_stop_callback(self):
"""
需要观察是否真的 EarlyStop
"""
data_set, model = prepare_env()
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=2, n_epochs=10, print_every=5, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=EarlyStopCallback(1), check_code_level=2)
trainer.train()




def test_control_C(): def test_control_C():
@@ -177,12 +188,10 @@ def test_control_C():
line1 = "\n\n\n\n\n*************************" line1 = "\n\n\n\n\n*************************"
line2 = "*************************\n\n\n\n\n" line2 = "*************************\n\n\n\n\n"
class Wait(Callback): class Wait(Callback):
def on_epoch_end(self): def on_epoch_end(self):
time.sleep(5) time.sleep(5)
data_set, model = prepare_env() data_set, model = prepare_env()
print(line1 + "Test starts!" + line2) print(line1 + "Test starts!" + line2)
@@ -204,4 +213,4 @@ def test_control_C():




if __name__ == "__main__": if __name__ == "__main__":
test_control_C()
test_control_C()

Loading…
Cancel
Save