Browse Source

更新了 ControlC 的 API

tags/v0.5.5
ChenXin 4 years ago
parent
commit
4557202023
5 changed files with 64 additions and 46 deletions
  1. +2
    -1
      README.md
  2. +3
    -0
      docs/source/index.rst
  3. +15
    -0
      docs/source/user/api_update.rst
  4. +14
    -8
      fastNLP/core/callback.py
  5. +30
    -37
      test/core/test_callbacks.py

+ 2
- 1
README.md View File

@@ -61,7 +61,8 @@ python -m spacy download en
### 扩展教程

- [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
- [Extend-2. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_fitlog.html)
- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)


## 内置组件


+ 3
- 0
docs/source/index.rst View File

@@ -35,6 +35,9 @@ API 文档
fastNLP


:doc:`API变动列表 </user/api_update>`

fitlog文档
----------



+ 15
- 0
docs/source/user/api_update.rst View File

@@ -0,0 +1,15 @@
===========================
API变动列表
===========================

2020.4.14
========================

修改了 :class:`fastNLP.core.callback.ControlC` 的 API。

原来的参数 ``quit_all`` 修改为 ``quit_and_do`` ,仍然接收一个 bool 值。新增可选参数 ``action`` ,接收一个待执行的函数,
在 ``quit_and_do`` 的值为 ``True`` 时,退出训练过程后执行该函数。 ``action`` 的默认值是退出整个程序,与原有功能一致。

.. note::
原有用法 `ControlC(True)` 和 `ControlC(False)` 均可以继续正确执行,但 `ControlC(quit_all=True/False)` 需要修改为
`ControlC(quit_and_do=True/False)` 。

+ 14
- 8
fastNLP/core/callback.py View File

@@ -695,20 +695,26 @@ class ControlC(Callback):
检测到 control+C 时的反馈
"""
def __init__(self, quit_all):
@staticmethod
def quit_all():
import sys
sys.exit(0) # 直接退出程序
def __init__(self, quit_and_do, action=quit_all):
r"""
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer
:param bool quit_and_do: 若为True,则检测到control+C 进行后续操作(默认值为:直接退出程序);否则只退出Trainer。
"""
super(ControlC, self).__init__()
if type(quit_all) != bool:
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.")
self.quit_all = quit_all
if type(quit_and_do) != bool:
raise ValueError("In KeyBoardInterrupt, quit_and_do arguemnt must be a bool.")
self.quit_and_do = quit_and_do
self.action = action
def on_exception(self, exception):
if isinstance(exception, KeyboardInterrupt):
if self.quit_all is True:
import sys
sys.exit(0) # 直接退出程序
if self.quit_and_do is True:
self.action()
else:
pass
else:


+ 30
- 37
test/core/test_callbacks.py View File

@@ -12,7 +12,7 @@ from fastNLP import Instance
from fastNLP import SGD
from fastNLP import Trainer
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
LRFinder, TensorboardCallback
LRFinder, TensorboardCallback, Callback
from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback
from fastNLP.core.callback import WarmupCallback
from fastNLP.models.base_model import NaiveClassifier
@@ -225,39 +225,32 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=EarlyStopCallback(1), check_code_level=2)
trainer.train()

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_control_C():
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试
from fastNLP import ControlC, Callback
import time

line1 = "\n\n\n\n\n*************************"
line2 = "*************************\n\n\n\n\n"

class Wait(Callback):
def on_epoch_end(self):
time.sleep(5)

data_set, model = prepare_env()

print(line1 + "Test starts!" + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(False)], check_code_level=2)
trainer.train()

print(line1 + "Program goes on ..." + line2)

trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(True)], check_code_level=2)
trainer.train()

print(line1 + "Test failed!" + line2)


if __name__ == "__main__":
test_control_C()
def test_control_C_callback(self):
class Raise(Callback):
def on_epoch_end(self):
raise KeyboardInterrupt
flags = [False]
def set_flag():
flags[0] = not flags[0]
data_set, model = prepare_env()
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Raise(), ControlC(False, set_flag)], check_code_level=2)
trainer.train()
self.assertEqual(flags[0], False)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Raise(), ControlC(True, set_flag)], check_code_level=2)
trainer.train()
self.assertEqual(flags[0], True)

Loading…
Cancel
Save