From 455720202387dc07b64d12f5b312d8e5dc138c79 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Tue, 14 Apr 2020 10:51:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=BA=86=20ControlC=20?= =?UTF-8?q?=E7=9A=84=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 3 +- docs/source/index.rst | 3 ++ docs/source/user/api_update.rst | 15 ++++++++ fastNLP/core/callback.py | 22 +++++++---- test/core/test_callbacks.py | 67 +++++++++++++++------------------ 5 files changed, 64 insertions(+), 46 deletions(-) create mode 100644 docs/source/user/api_update.rst diff --git a/README.md b/README.md index f2bd3501..74090646 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,8 @@ python -m spacy download en ### 扩展教程 - [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html) -- [Extend-2. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_fitlog.html) +- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html) +- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html) ## 内置组件 diff --git a/docs/source/index.rst b/docs/source/index.rst index e175dd94..4db6dea6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,9 @@ API 文档 fastNLP + +:doc:`API变动列表 ` + fitlog文档 ---------- diff --git a/docs/source/user/api_update.rst b/docs/source/user/api_update.rst new file mode 100644 index 00000000..08a6bdbe --- /dev/null +++ b/docs/source/user/api_update.rst @@ -0,0 +1,15 @@ +=========================== +API变动列表 +=========================== + +2020.4.14 +======================== + +修改了 :class:`fastNLP.core.callback.ControlC` 的 API。 + +原来的参数 ``quit_all`` 修改为 ``quit_and_do`` ,仍然接收一个 bool 值。新增可选参数 ``action`` ,接收一个待执行的函数, +在 ``quit_and_do`` 的值为 ``True`` 时,退出训练过程后执行该函数。 ``action`` 的默认值是退出整个程序,与原有功能一致。 + +.. note:: + 原有用法 `ControlC(True)` 和 `ControlC(False)` 均可以继续正确执行,但 `ControlC(quit_all=True/False)` 需要修改为 + `ControlC(quit_and_do=True/False)` 。 \ No newline at end of file diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 21055c15..c1ead8c5 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -695,20 +695,26 @@ class ControlC(Callback): 检测到 control+C 时的反馈 """ - def __init__(self, quit_all): + @staticmethod + def quit_all(): + import sys + sys.exit(0) # 直接退出程序 + + def __init__(self, quit_and_do, action=quit_all): r""" - :param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer + :param bool quit_and_do: 若为True,则检测到control+C 进行后续操作(默认值为:直接退出程序);否则只退出Trainer。 """ + super(ControlC, self).__init__() - if type(quit_all) != bool: - raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") - self.quit_all = quit_all + if type(quit_and_do) != bool: + raise ValueError("In KeyBoardInterrupt, quit_and_do arguemnt must be a bool.") + self.quit_and_do = quit_and_do + self.action = action def on_exception(self, exception): if isinstance(exception, KeyboardInterrupt): - if self.quit_all is True: - import sys - sys.exit(0) # 直接退出程序 + if self.quit_and_do is True: + self.action() else: pass else: diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index e756040c..165d7004 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -12,7 +12,7 @@ from fastNLP import Instance from fastNLP import SGD from fastNLP import Trainer from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ - LRFinder, TensorboardCallback + LRFinder, TensorboardCallback, Callback from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback from fastNLP.core.callback import WarmupCallback from fastNLP.models.base_model import NaiveClassifier @@ -225,39 +225,32 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=EarlyStopCallback(1), check_code_level=2) trainer.train() - -@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") -def test_control_C(): - # 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 - from fastNLP import ControlC, Callback - import time - - line1 = "\n\n\n\n\n*************************" - line2 = "*************************\n\n\n\n\n" - - class Wait(Callback): - def on_epoch_end(self): - time.sleep(5) - - data_set, model = prepare_env() - - print(line1 + "Test starts!" + line2) - trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), - batch_size=32, n_epochs=20, dev_data=data_set, - metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, - callbacks=[Wait(), ControlC(False)], check_code_level=2) - trainer.train() - - print(line1 + "Program goes on ..." + line2) - - trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), - batch_size=32, n_epochs=20, dev_data=data_set, - metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, - callbacks=[Wait(), ControlC(True)], check_code_level=2) - trainer.train() - - print(line1 + "Test failed!" + line2) - - -if __name__ == "__main__": - test_control_C() + + def test_control_C_callback(self): + + class Raise(Callback): + def on_epoch_end(self): + raise KeyboardInterrupt + + flags = [False] + + def set_flag(): + flags[0] = not flags[0] + + data_set, model = prepare_env() + + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=32, n_epochs=20, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=[Raise(), ControlC(False, set_flag)], check_code_level=2) + trainer.train() + + self.assertEqual(flags[0], False) + + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=32, n_epochs=20, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=[Raise(), ControlC(True, set_flag)], check_code_level=2) + trainer.train() + + self.assertEqual(flags[0], True)