Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc 5 years ago
parent
commit
91d9b239c2
9 changed files with 315 additions and 55 deletions
  1. +2
    -1
      README.md
  2. +3
    -0
      docs/source/index.rst
  3. +223
    -0
      docs/source/tutorials/extend_2_dist.rst
  4. +0
    -0
      docs/source/tutorials/extend_3_fitlog.rst
  5. +15
    -0
      docs/source/user/api_update.rst
  6. +2
    -1
      docs/source/user/tutorials.rst
  7. +15
    -9
      fastNLP/core/callback.py
  8. +25
    -7
      fastNLP/core/dist_trainer.py
  9. +30
    -37
      test/core/test_callbacks.py

+ 2
- 1
README.md View File

@@ -61,7 +61,8 @@ python -m spacy download en
### 扩展教程 ### 扩展教程


- [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html) - [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
- [Extend-2. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_fitlog.html)
- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)




## 内置组件 ## 内置组件


+ 3
- 0
docs/source/index.rst View File

@@ -35,6 +35,9 @@ API 文档
fastNLP fastNLP



:doc:`API变动列表 </user/api_update>`

fitlog文档 fitlog文档
---------- ----------




+ 223
- 0
docs/source/tutorials/extend_2_dist.rst View File

@@ -0,0 +1,223 @@
Distributed Parallel Training
=============================

原理
----

随着深度学习模型越来越复杂,单个GPU可能已经无法满足正常的训练。比如BERT等预训练模型,更是在多个GPU上训练得到的。为了使用多GPU训练,Pytorch框架已经提供了
`nn.DataParallel <https://pytorch.org/docs/stable/nn.html#dataparallel>`_ 以及
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_ 两种方式的支持。
`nn.DataParallel <https://pytorch.org/docs/stable/nn.html#dataparallel>`_
很容易使用,但是却有着GPU负载不均衡,单进程速度慢等缺点,无法发挥出多GPU的全部性能。因此,分布式的多GPU训练方式
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_
是更好的选择。然而,因为分布式训练的特点,
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_
常常难以理解和使用,也很难debug。所以,在使用分布式训练之前,需要理解它的原理。

在使用
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_
时,模型会被复制到所有使用的GPU,通常每个GPU上存有一个模型,并被一个单独的进程控制。这样有N块GPU,就会产生N个进程。当训练一个batch时,这一batch会被分为N份,每个进程会使用batch的一部分进行训练,然后在必要时进行同步,并通过网络传输需要同步的数据。这时,只有模型的梯度会被同步,而模型的参数不会,所以能缓解大部分的网络传输压力,网络传输不再是训练速度的瓶颈之一。你可能会好奇,不同步模型的参数,怎么保证不同进程所训练的模型相同?只要每个进程初始的模型是同一个,具有相同的参数,而之后每次更新,都使用相同的梯度,就能保证梯度更新后的模型也具有相同的参数了。

为了让每个进程的模型初始化完全相同,通常这N个进程都是由单个进程复制而来的,这时需要对分布式的进程进行初始化,建立相互通信的机制。在
Pytorch 中,我们用
`distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
函数来完成,需要在程序开头就加入这一步骤。初始化完成后,每一个进程用唯一的编号
``rank`` 进行区分,从 0 到 N-1递增,一般地,我们将 ``rank`` 为 0
的进程当作主进程,而其他 ``rank`` 的进程为子进程。每个进程还要知道
``world_size`` ,即分布式训练的总进程数
N。训练时,每个进程使用batch的一部分,互相不能重复,这里通过
`nn.utils.data.DistributedSampler <https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html>`_
来实现。

使用方式
--------

Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给出的\ `官方教程 <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ \ 中看到。而\ ``fastNLP``
提供了
``DistTrainer``\ ,将大部分的分布式训练的细节进行了封装,只需简单的改动训练代码,就能直接用上分布式训练。那么,具体怎么将普通的训练代码改成支持分布式训练的代码呢。下面我们来讲一讲分布式训练的完整流程。通常,分布式程序的多个进程是单个进程的复制。假设我们用N个GPU进行分布式训练,我们需要启动N个进程,这时,在命令行使用:

.. code:: shell

python -m torch.distributed.launch --nproc_per_node=N train_script.py --args

其中\ ``N``\ 是需要启动的进程数,\ ``train_script.py``\ 为训练代码,\ ``--args``\ 是自定义的命令行参数。在启动了N个进程之后,如果我们在\ ``train_script.py``\ 的训练代码中正常配置,分布式训练就能正常进行。

此外,还可以使用环境变量\ ``CUDA_VISIBLE_DEVICES``\ 设置指定的GPU,比如在8卡机器上使用编号为4,5,6,7的4块GPU:

.. code:: shell

CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=N train_script.py --args

在 ``train_script.py``
训练代码中,有一些必须的配置。为了清晰的叙述,这里放一个简单的分布式训练代码,省去多余细节:

.. code:: python

import torch.distributed as dist
from fastNLP import DistTrainer, get_local_rank
import fastNLP as fnlp

def main(options):
# options为训练所需的参数,batch_size等
set_seed(options.seed)
# 初始化分布式进程
dist.init_process_group('nccl')

######## 读取数据
if get_local_rank() != 0:
dist.barrier() # 先让主进程(rank==0)先执行,进行数据处理,预训模型参数下载等操作,然后保存cache
data = get_processed_data()
model = get_model(data.get_vocab("words"), data.get_vocab("target"))
if get_local_rank() == 0:
dist.barrier() # 主进程执行完后,其余进程开始读取cache
########

# 初始化Trainer,训练等,与普通训练差别不大
def get_trainer(model, data):
# 注意设置的callback有两种,一种只在主进程执行,一种在所有进程都执行
callbacks_master = [fnlp.FitlogCallback()]
callbacks_all = [fnlp.WarmupCallback(warmup=options.warmup)]
trainer = DistTrainer(
save_path='save',
train_data=data.get_dataset("train"),
dev_data=data.get_dataset("dev"),
model=model,
loss=fnlp.CrossEntropyLoss(),
metrics=fnlp.AccuracyMetric(),
metric_key="acc",
optimizer=fnlp.AdamW(model.parameters(), lr=options.lr),
callbacks_master=callbacks_master, # 仅在主进程执行(如模型保存,日志记录)
callbacks_all=callbacks_all, # 在所有进程都执行(如梯度裁剪,学习率衰减)
batch_size_per_gpu=options.batch_size, # 指定每个GPU的batch大小
update_every=options.update,
n_epochs=options.epochs,
use_tqdm=True,
)
return trainer
trainer = get_trainer(model, data)
trainer.train()

指定进程编号
^^^^^^^^^^^^

首先,为了区分不同的进程,初始时需要对每个进程传入\ ``rank``\ 。这里一般分为\ ``node_rank``\ 和\ ``local_rank``\ ,分别表示进程处于哪一机器以及同机器上处于第几进程。如果在单一机器上,\ ``node_rank``\ 可以省略。\ ``local_rank``\ 一般通过命令行参数\ ``--local_rank``\ 传入,为\ ``int``\ 类型。也可以通过环境变量传入\ ``local_rank``\ ,只需在\ ``torch.distributed.launch``\ 时,使用\ ``--use_env``\ 参数。无论哪种方式,在训练脚本中,都要获取到\ ``local_rank``\ ,用于初始化分布式通信,以及区分进程。如果你使用\ ``fastNLP``\ ,可以通过\ ``fastNLP.get_local_rank``\ 来得到\ ``local_rank``\ 。

初始化进程
^^^^^^^^^^

在获取了\ ``local_rank``\ 等重要参数后,在开始训练前,我们需要建立不同进程的通信和同步机制。这时我们使用\ `torch.distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
来完成。通常,我们只需要 ``torch.distributed.init_process_group('nccl')``
来指定使用\ ``nccl``\ 后端来进行同步即可。其他参数程序将读取环境变量自动设置。如果想手动设置这些参数,比如,使用TCP进行通信,可以设置:

.. code:: python

init_process_group('nccl', init_method='tcp://localhost:55678',
rank=args.rank, world_size=N)

或者使用文件进行通信:

.. code:: python

init_process_group('nccl', init_method='file:///mnt/nfs/sharedfile',
world_size=N, rank=args.rank)

注意,此时必须显式指定\ ``world_size``\ 和\ ``rank``\ ,具体可以参考
`torch.distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
的使用文档。

在初始化分布式通信后,再初始化\ ``DistTrainer``\ ,传入数据和模型,就完成了分布式训练的代码。代码修改完成后,使用上面给出的命令行启动脚本,就能成功运行分布式训练。但是,如果数据处理,训练中的自定义操作比较复杂,则可能需要额外的代码修改。下面列出一些需要特别注意的地方,在使用分布式训练前,请仔细检查这些事项。

注意事项
--------

在执行完
`torch.distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
后,我们就可以在不同进程间完成传输数据,进行同步等操作。这些操作都可以在\ `torch.distributed <https://pytorch.org/docs/stable/distributed.html#>`_
中找到。其中,最重要的是
`barrier <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>`_
以及
`get_rank <https://pytorch.org/docs/stable/distributed.html#torch.distributed.get_rank>`_
操作。对于训练而言,我们关心的是读入数据,记录日志,模型初始化,模型参数更新,模型保存等操作。这些操作大多是读写操作,在多进程状态下,这些操作都必须小心进行,否则可能出现难以预料的bug。而在\ ``fastNLP``\ 中,大部分操作都封装在
``DistTrainer`` 中,只需保证数据读入和模型初始化正确即可完成训练。

写操作
^^^^^^

一般而言,读入操作需要在每一个进程都执行,因为每个进程都要使用读入的数据和模型参数进行训练。而写出操作只需在其中一个进程(通常为主进程)执行,因为每一个进程保存的模型都相同,都处于同一训练状态。所以,通常单进程的训练脚本中,只需要修改写出操作的部分,通过加入对进程\ ``rank``\ 的判断,仅让其中一个进程执行写操作:

.. code:: python

import torch.distributed as dist

# 仅在主进程才执行
if dist.get_rank() == 0:
do_wirte_op() # 一些写操作
dist.barrier() # 确保写完成后,所有进程再执行(若进程无需读入写出的数据,可以省去)

若使用\ ``fastNLP``\ 中的\ ``DistTrainer``\ ,也可以这样写:

.. code:: python

# 判断是否是主进程的trainer
if trainer.is_master:
do_wirte_op()
dist.barrier()

读操作
^^^^^^

然而有些时候,我们需要其中一个进程先执行某些操作,等这一进程执行完后,其它进程再执行这一操作。比如,在读入数据时,我们有时需要从网上下载,再处理,将处理好的数据保存,供反复使用。这时,我们不需要所有进程都去下载和处理数据,只需要主进程进行这些操作,其它进程等待。直到处理好的数据被保存后,其他进程再从保存位置直接读入数据。这里可以参考范例代码中的读取数据:

.. code:: python

if dist.get_rank() != 0:
dist.barrier() # 先让主进程(rank==0)先执行,进行数据处理,预训模型参数下载等操作,然后保存cache

# 这里会自动处理数据,或直接读取保存的cache
data = get_processed_data()
model = get_model(data.get_vocab("words"), data.get_vocab("target"))

if dist.get_rank() == 0:
dist.barrier() # 主进程执行完后,其余进程开始读取cache

也可以显式的将主进程和其它进程的操作分开:

.. code:: python

if dist.get_rank() == 0:
data = do_data_processing() # 数据处理
dist.barrier()
else:
dist.barrier()
data = load_processed_data() # 读取cache

日志操作
^^^^^^^^

通常,我们需要知道训练的状态,如当前在第几个epoch,模型当前的loss等等。单进程训练时,我们可以直接使用\ ``print``\ 将这些信息输出到命令行或日志文件。然而,在多进程时,\ ``print``\ 会导致同样的信息在每一进程都输出,造成问题。这一问题和写操作类似,也可以通过判断进程的编号之后再输出。问题是,日志通常在训练的很多地方都有输出,逐一加上判断代码是非常繁琐的。这里,建议统一修改为:

.. code:: python

from fastNLP import logger
logger.info('....') # 替换print

在\ ``DistTrainer``\ 中,主进程的\ ``logger``\ 级别为\ ``INFO``\ ,而其它进程为\ ``WARNING``\ 。这样级别为\ ``INFO``\ 的信息只会在主进程输出,不会造成日志重复问题。若需要其它进程中的信息,可以使用\ ``logger.warning``\ 。

注意,\ ``logger``\ 的级别设置只有初始化了\ ``DistTrainer``\ 后才能生效。如果想要在初始化进程后就生效,需要在分布式通信初始化后,执行\ ``init_logger_dist``\ 。

Callback
^^^^^^^^

``fastNLP``\ 的一个特色是可以使用\ ``Callback``\ 在训练时完成各种自定义操作。而这一特色在\ ``DistTrainer``\ 中得以保留。但是,这时需要特别注意\ ``Callback``\ 是否只需要在主进程执行。一些\ ``Callback``\ ,比如调整学习率,梯度裁剪等,会改变模型的状态,因此需要在所有进程上都执行,将它们通过\ ``callback_all``\ 参数传入\ ``DistTrainer``\ 。而另一些\ ``Callback``\ ,比如\ ``fitlog``\ ,保存模型,不会改变模型的状态,而是进行数据写操作,因此仅在主进程上执行,将它们通过\ ``callback_master``\ 传入。

在自定义\ ``Callback``\ 时,请遵循一个原则,改变训练或模型状态的操作在所有进程中执行,而数据写到硬盘请在主进程单独进行。这样就能避免进程间失去同步,或者磁盘写操作的冲突。

Debug
^^^^^

多进程的程序很难进行debug,如果出现问题,可以先参考报错信息进行处理。也可以在程序中多输出日志,定位问题。具体情况,具体分析。在debug时,要多考虑进程同步和异步的操作,判断问题是程序本身导致的,还是由进程间没有同步而产生。

其中,有一个常见问题是程序卡住不动。具体表现为训练暂停,程序没有输出,但是GPU利用率保持100%。这一问题是由进程失去同步导致的。这时只能手动\ ``kill``\ GPU上残留的进程,再检查代码。需要检查进程同步的位置,比如模型\ ``backward()``\ 时,\ ``barrier()``\ 时等。同时,也要检查主进程与其它进程操作不同的位置,比如存储模型,evaluate模型时等。注意,失去同步的位置可能并不是程序卡住的位置,所以需要细致的检查。

docs/source/tutorials/extend_2_fitlog.rst → docs/source/tutorials/extend_3_fitlog.rst View File


+ 15
- 0
docs/source/user/api_update.rst View File

@@ -0,0 +1,15 @@
===========================
API变动列表
===========================

2020.4.14
========================

修改了 :class:`fastNLP.core.callback.ControlC` 的 API。

原来的参数 ``quit_all`` 修改为 ``quit_and_do`` ,仍然接收一个 bool 值。新增可选参数 ``action`` ,接收一个待执行的函数,
在 ``quit_and_do`` 的值为 ``True`` 时,退出训练过程后执行该函数。 ``action`` 的默认值是退出整个程序,与原有功能一致。

.. note::
原有用法 `ControlC(True)` 和 `ControlC(False)` 均可以继续正确执行,但 `ControlC(quit_all=True/False)` 需要修改为
`ControlC(quit_and_do=True/False)` 。

+ 2
- 1
docs/source/user/tutorials.rst View File

@@ -21,4 +21,5 @@ fastNLP 详细使用教程
:maxdepth: 1 :maxdepth: 1


拓展阅读1:BertEmbedding的各种用法 </tutorials/extend_1_bert_embedding> 拓展阅读1:BertEmbedding的各种用法 </tutorials/extend_1_bert_embedding>
拓展阅读2:使用fitlog 辅助 fastNLP 进行科研 </tutorials/extend_2_fitlog>
拓展阅读2:分布式训练简介 </tutorials/extend_2_dist>
拓展阅读3:使用fitlog 辅助 fastNLP 进行科研 </tutorials/extend_3_fitlog>

+ 15
- 9
fastNLP/core/callback.py View File

@@ -695,20 +695,26 @@ class ControlC(Callback):
检测到 control+C 时的反馈 检测到 control+C 时的反馈
""" """
def __init__(self, quit_all):
@staticmethod
def quit_all():
import sys
sys.exit(0) # 直接退出程序
def __init__(self, quit_and_do, action=quit_all):
r""" r"""
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer
:param bool quit_and_do: 若为True,则检测到control+C 进行后续操作(默认值为:直接退出程序);否则只退出Trainer。
""" """
super(ControlC, self).__init__() super(ControlC, self).__init__()
if type(quit_all) != bool:
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.")
self.quit_all = quit_all
if type(quit_and_do) != bool:
raise ValueError("In KeyBoardInterrupt, quit_and_do arguemnt must be a bool.")
self.quit_and_do = quit_and_do
self.action = action
def on_exception(self, exception): def on_exception(self, exception):
if isinstance(exception, KeyboardInterrupt): if isinstance(exception, KeyboardInterrupt):
if self.quit_all is True:
import sys
sys.exit(0) # 直接退出程序
if self.quit_and_do is True:
self.action()
else: else:
pass pass
else: else:
@@ -809,7 +815,7 @@ class TensorboardCallback(Callback):
.. warning:: .. warning::
fastNLP 已停止对此功能的维护,请等待 fastNLP 兼容 PyTorch1.1 的下一个版本。 fastNLP 已停止对此功能的维护,请等待 fastNLP 兼容 PyTorch1.1 的下一个版本。
或者使用和 fastNLP 高度配合的 fitlog(参见 :doc:`/tutorials/extend_2_fitlog` )。
或者使用和 fastNLP 高度配合的 fitlog(参见 :doc:`/tutorials/extend_3_fitlog` )。
""" """


+ 25
- 7
fastNLP/core/dist_trainer.py View File

@@ -19,6 +19,7 @@ from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
import time


from ._logger import logger, init_logger_dist from ._logger import logger, init_logger_dist
from .batch import DataSetIter, BatchIter from .batch import DataSetIter, BatchIter
@@ -175,8 +176,12 @@ class DistTrainer():
self.test_manager.add_callback([cb], master=True) self.test_manager.add_callback([cb], master=True)


# Setup logging # Setup logging
dist.barrier()
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
# 同步start_time
sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
dist.broadcast(sync_time, src=0)
self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
# print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))

if self.save_path: if self.save_path:
self.cp_save_path = self.save_path self.cp_save_path = self.save_path
else: else:
@@ -208,9 +213,7 @@ class DistTrainer():
return contextlib.ExitStack() # dummy contextmanager return contextlib.ExitStack() # dummy contextmanager


def _get_n_steps(self): def _get_n_steps(self):
batch_size = self.world_size * self.batch_size_per_gpu
return (len(self.train_data) // batch_size + int(
len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs
return len(self.data_iterator) * self.n_epochs


def _get_data_iter(self, dataset): def _get_data_iter(self, dataset):
if isinstance(dataset, DataSet): if isinstance(dataset, DataSet):
@@ -432,8 +435,9 @@ class DistTrainer():
model_load = model_load.state_dict() model_load = model_load.state_dict()
self.model.load_state_dict(model_load) self.model.load_state_dict(model_load)


def _best_save_name(self):
return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
def _best_save_name(self, auto_fix=True):
best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time])
return best_name


def _do_validation(self): def _do_validation(self):
with self.ddp_model.no_sync(): with self.ddp_model.no_sync():
@@ -447,12 +451,26 @@ class DistTrainer():
is_better = is_better[0] is_better = is_better[0]
else: else:
eval_res, is_better = None, None eval_res, is_better = None, None
if self.metric_key is None and eval_res is not None:
eval_res0 = list(eval_res.values())[0]
self.metric_key = list(eval_res0.keys())[0]
# logger.info('{}, {}'.format(eval_res, is_better)) # logger.info('{}, {}'.format(eval_res, is_better))
# save better model on master node # save better model on master node
if is_better is not None and self.cp_save_path: if is_better is not None and self.cp_save_path:
if is_better: if is_better:
self.save_check_point(self._best_save_name(), only_params=False) self.save_check_point(self._best_save_name(), only_params=False)
dist.barrier() dist.barrier()

if not self.is_master and self.metric_key is None:
# 主进程自动得到了metric_key,而其它进程没有
prefix = 'best_' + self.model.__class__.__name__
suffix = self.start_time
fn_list = os.listdir(self.cp_save_path)
fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
if len(fn_list) == 1:
best_name = fn_list[0]
self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
# print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
self.callback_manager.on_valid_end( self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better) eval_res, self.metric_key, self.optimizer, is_better)
self.ddp_model.train() self.ddp_model.train()


+ 30
- 37
test/core/test_callbacks.py View File

@@ -12,7 +12,7 @@ from fastNLP import Instance
from fastNLP import SGD from fastNLP import SGD
from fastNLP import Trainer from fastNLP import Trainer
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
LRFinder, TensorboardCallback
LRFinder, TensorboardCallback, Callback
from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback
from fastNLP.core.callback import WarmupCallback from fastNLP.core.callback import WarmupCallback
from fastNLP.models.base_model import NaiveClassifier from fastNLP.models.base_model import NaiveClassifier
@@ -225,39 +225,32 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=EarlyStopCallback(1), check_code_level=2) callbacks=EarlyStopCallback(1), check_code_level=2)
trainer.train() trainer.train()

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_control_C():
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试
from fastNLP import ControlC, Callback
import time

line1 = "\n\n\n\n\n*************************"
line2 = "*************************\n\n\n\n\n"

class Wait(Callback):
def on_epoch_end(self):
time.sleep(5)

data_set, model = prepare_env()

print(line1 + "Test starts!" + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(False)], check_code_level=2)
trainer.train()

print(line1 + "Program goes on ..." + line2)

trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(True)], check_code_level=2)
trainer.train()

print(line1 + "Test failed!" + line2)


if __name__ == "__main__":
test_control_C()
def test_control_C_callback(self):
class Raise(Callback):
def on_epoch_end(self):
raise KeyboardInterrupt
flags = [False]
def set_flag():
flags[0] = not flags[0]
data_set, model = prepare_env()
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Raise(), ControlC(False, set_flag)], check_code_level=2)
trainer.train()
self.assertEqual(flags[0], False)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Raise(), ControlC(True, set_flag)], check_code_level=2)
trainer.train()
self.assertEqual(flags[0], True)

Loading…
Cancel
Save