|
@@ -167,6 +167,17 @@ class TestCallback(unittest.TestCase): |
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
callbacks=warmup_callback, check_code_level=2) |
|
|
callbacks=warmup_callback, check_code_level=2) |
|
|
trainer.train() |
|
|
trainer.train() |
|
|
|
|
|
|
|
|
|
|
|
def test_early_stop_callback(self): |
|
|
|
|
|
""" |
|
|
|
|
|
需要观察是否真的 EarlyStop |
|
|
|
|
|
""" |
|
|
|
|
|
data_set, model = prepare_env() |
|
|
|
|
|
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), |
|
|
|
|
|
batch_size=2, n_epochs=10, print_every=5, dev_data=data_set, |
|
|
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
|
|
|
callbacks=EarlyStopCallback(1), check_code_level=2) |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_control_C(): |
|
|
def test_control_C(): |
|
@@ -177,12 +188,10 @@ def test_control_C(): |
|
|
line1 = "\n\n\n\n\n*************************" |
|
|
line1 = "\n\n\n\n\n*************************" |
|
|
line2 = "*************************\n\n\n\n\n" |
|
|
line2 = "*************************\n\n\n\n\n" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Wait(Callback): |
|
|
class Wait(Callback): |
|
|
def on_epoch_end(self): |
|
|
def on_epoch_end(self): |
|
|
time.sleep(5) |
|
|
time.sleep(5) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_set, model = prepare_env() |
|
|
data_set, model = prepare_env() |
|
|
|
|
|
|
|
|
print(line1 + "Test starts!" + line2) |
|
|
print(line1 + "Test starts!" + line2) |
|
@@ -204,4 +213,4 @@ def test_control_C(): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
if __name__ == "__main__": |
|
|
test_control_C() |
|
|
|
|
|
|
|
|
test_control_C() |