|
|
@@ -12,7 +12,7 @@ from fastNLP import Instance |
|
|
|
from fastNLP import SGD |
|
|
|
from fastNLP import Trainer |
|
|
|
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ |
|
|
|
LRFinder, TensorboardCallback |
|
|
|
LRFinder, TensorboardCallback, Callback |
|
|
|
from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback |
|
|
|
from fastNLP.core.callback import WarmupCallback |
|
|
|
from fastNLP.models.base_model import NaiveClassifier |
|
|
@@ -225,39 +225,32 @@ class TestCallback(unittest.TestCase): |
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
|
callbacks=EarlyStopCallback(1), check_code_level=2) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
def test_control_C(): |
|
|
|
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 |
|
|
|
from fastNLP import ControlC, Callback |
|
|
|
import time |
|
|
|
|
|
|
|
line1 = "\n\n\n\n\n*************************" |
|
|
|
line2 = "*************************\n\n\n\n\n" |
|
|
|
|
|
|
|
class Wait(Callback): |
|
|
|
def on_epoch_end(self): |
|
|
|
time.sleep(5) |
|
|
|
|
|
|
|
data_set, model = prepare_env() |
|
|
|
|
|
|
|
print(line1 + "Test starts!" + line2) |
|
|
|
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), |
|
|
|
batch_size=32, n_epochs=20, dev_data=data_set, |
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
|
callbacks=[Wait(), ControlC(False)], check_code_level=2) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
print(line1 + "Program goes on ..." + line2) |
|
|
|
|
|
|
|
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), |
|
|
|
batch_size=32, n_epochs=20, dev_data=data_set, |
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
|
callbacks=[Wait(), ControlC(True)], check_code_level=2) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
print(line1 + "Test failed!" + line2) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_control_C() |
|
|
|
|
|
|
|
def test_control_C_callback(self): |
|
|
|
|
|
|
|
class Raise(Callback): |
|
|
|
def on_epoch_end(self): |
|
|
|
raise KeyboardInterrupt |
|
|
|
|
|
|
|
flags = [False] |
|
|
|
|
|
|
|
def set_flag(): |
|
|
|
flags[0] = not flags[0] |
|
|
|
|
|
|
|
data_set, model = prepare_env() |
|
|
|
|
|
|
|
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), |
|
|
|
batch_size=32, n_epochs=20, dev_data=data_set, |
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
|
callbacks=[Raise(), ControlC(False, set_flag)], check_code_level=2) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
self.assertEqual(flags[0], False) |
|
|
|
|
|
|
|
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), |
|
|
|
batch_size=32, n_epochs=20, dev_data=data_set, |
|
|
|
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, |
|
|
|
callbacks=[Raise(), ControlC(True, set_flag)], check_code_level=2) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
self.assertEqual(flags[0], True) |