Browse Source

1.修改ProgressBar在Trainer中的一个bug;2修复pytest的bug

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
b38fc1136e
7 changed files with 61 additions and 38 deletions
  1. +41
    -13
      fastNLP/core/callbacks/callback_manager.py
  2. +0
    -2
      fastNLP/core/callbacks/progress_callback.py
  3. +3
    -13
      fastNLP/core/controllers/trainer.py
  4. +7
    -4
      tests/core/callbacks/test_checkpoint_callback_torch.py
  5. +2
    -2
      tests/core/callbacks/test_load_best_model_callback_torch.py
  6. +2
    -1
      tests/core/callbacks/test_more_evaluate_callback.py
  7. +6
    -3
      tests/helpers/utils.py

+ 41
- 13
fastNLP/core/callbacks/callback_manager.py View File

@@ -9,6 +9,8 @@ __all__ = [
from .callback_events import Events from .callback_events import Events
from .callback import Callback from .callback import Callback
from fastNLP.core.log import logger from fastNLP.core.log import logger
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.envs import rank_zero_call




def _transfer(func): def _transfer(func):
@@ -26,6 +28,43 @@ def _transfer(func):
return wrapper return wrapper




def prepare_callbacks(callbacks, progress_bar):
    """
    Normalize the user-supplied callbacks and make sure a progress-bar callback
    is attached when appropriate.

    :param callbacks: a single ``Callback``, a Sequence of ``Callback`` objects,
        or ``None``;
    :param progress_bar: how progress should be displayed; forwarded to
        ``choose_progress_callback`` (e.g. ``None``/``'raw'``/``'rich'``/``'auto'``
        — exact accepted values are defined by that helper) when ``callbacks``
        does not already contain a ``ProgressCallback``;
    :return: a list of ``Callback`` objects, possibly extended with one
        progress callback.
    :raises ValueError: if ``callbacks`` is neither a ``Callback`` nor a Sequence;
    :raises TypeError: if any element of ``callbacks`` is not a ``Callback``.
    """
    _callbacks = []
    if callbacks is not None:
        if isinstance(callbacks, Callback):
            callbacks = [callbacks]
        if not isinstance(callbacks, Sequence):
            raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.")
        callbacks = list(callbacks)
        for _callback in callbacks:
            if not isinstance(_callback, Callback):
                raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
        _callbacks += callbacks

    # BUG FIX: the flag was previously named `has_no_progress` but was set to
    # True when a ProgressCallback WAS found, inverting the "no progress bar"
    # warning below so it fired exactly when a progress bar existed.
    has_progress = any(isinstance(_callback, ProgressCallback) for _callback in _callbacks)

    if not has_progress:
        callback = choose_progress_callback(progress_bar)
        if callback is not None:
            _callbacks.append(callback)
        elif progress_bar is None:
            # No progress callback supplied and none selected: nothing will be
            # printed during training, so warn once (rank 0 only).
            rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
                                           "during training.")
    elif progress_bar is not None and progress_bar != 'auto':
        # User passed both a ProgressCallback instance and an explicit
        # progress_bar setting; the latter is ignored.
        logger.warning("Since you have passed in ProgressBar callback, progress_bar will be ignored.")

    return _callbacks


class CallbackManager: class CallbackManager:
r""" r"""
用来管理训练过程中的所有的 callback 实例; 用来管理训练过程中的所有的 callback 实例;
@@ -45,24 +84,13 @@ class CallbackManager:
""" """
self._need_reproducible_sampler = False self._need_reproducible_sampler = False


_callbacks = []
if callbacks is not None:
if isinstance(callbacks, Callback):
callbacks = [callbacks]
if not isinstance(callbacks, Sequence):
raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.")
callbacks = list(callbacks)
for _callback in callbacks:
if not isinstance(_callback, Callback):
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
_callbacks += callbacks
self.callback_fns = defaultdict(list) self.callback_fns = defaultdict(list)
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可;
self.callback_counter = defaultdict(lambda: 0) self.callback_counter = defaultdict(lambda: 0)
if len(_callbacks):
if len(callbacks):
# 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它; # 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它;
self.class_callbacks = _callbacks
self.class_callbacks = callbacks
else: else:
self.class_callbacks: Optional[List[Callback]] = [] self.class_callbacks: Optional[List[Callback]] = []




+ 0
- 2
fastNLP/core/callbacks/progress_callback.py View File

@@ -11,8 +11,6 @@ __all__ = [
from .has_monitor_callback import HasMonitorCallback from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils import f_rich_progress from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.utils.utils import is_notebook





class ProgressCallback(HasMonitorCallback): class ProgressCallback(HasMonitorCallback):


+ 3
- 13
fastNLP/core/controllers/trainer.py View File

@@ -19,8 +19,8 @@ from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_manager import prepare_callbacks
from fastNLP.core.callbacks.callback_events import _SingleEventState from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.progress_callback import choose_progress_callback
from fastNLP.core.drivers import Driver from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -133,7 +133,7 @@ class Trainer(TrainerEventTrigger):
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
@@ -212,17 +212,7 @@ class Trainer(TrainerEventTrigger):
self.driver.set_optimizers(optimizers=optimizers) self.driver.set_optimizers(optimizers=optimizers)


# 根据 progress_bar 参数选择 ProgressBarCallback # 根据 progress_bar 参数选择 ProgressBarCallback
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto'))
if progress_bar_callback is not None:
if callbacks is None:
callbacks = []
elif not isinstance(callbacks, Sequence):
callbacks = [callbacks]

callbacks = list(callbacks) + [progress_bar_callback]
else:
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
"during training.")
callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto'))
# 初始化 callback manager; # 初始化 callback manager;
self.callback_manager = CallbackManager(callbacks) self.callback_manager = CallbackManager(callbacks)
# 添加所有的函数式 callbacks; # 添加所有的函数式 callbacks;


+ 7
- 4
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -73,7 +73,7 @@ def model_and_optimizers(request):
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False]) @pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1( def test_model_checkpoint_callback_1(
model_and_optimizers: TrainerParameters, model_and_optimizers: TrainerParameters,
driver, driver,
@@ -193,7 +193,7 @@ def test_model_checkpoint_callback_1(
trainer.load_model(folder, only_state_dict=only_state_dict) trainer.load_model(folder, only_state_dict=only_state_dict)


trainer.run() trainer.run()
trainer.driver.barrier()
finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


@@ -203,7 +203,7 @@ def test_model_checkpoint_callback_1(


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("only_state_dict", [True]) @pytest.mark.parametrize("only_state_dict", [True])
@magic_argv_env_context
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_2( def test_model_checkpoint_callback_2(
model_and_optimizers: TrainerParameters, model_and_optimizers: TrainerParameters,
driver, driver,
@@ -283,6 +283,7 @@ def test_model_checkpoint_callback_2(


trainer.load_model(folder, only_state_dict=only_state_dict) trainer.load_model(folder, only_state_dict=only_state_dict)
trainer.run() trainer.run()
trainer.driver.barrier()


finally: finally:
rank_zero_rm(path) rank_zero_rm(path)
@@ -295,7 +296,7 @@ def test_model_checkpoint_callback_2(
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False]) @pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
@magic_argv_env_context(timeout=100)
def test_trainer_checkpoint_callback_1( def test_trainer_checkpoint_callback_1(
model_and_optimizers: TrainerParameters, model_and_optimizers: TrainerParameters,
driver, driver,
@@ -413,6 +414,7 @@ def test_trainer_checkpoint_callback_1(
trainer.load(folder, only_state_dict=only_state_dict) trainer.load(folder, only_state_dict=only_state_dict)


trainer.run() trainer.run()
trainer.driver.barrier()


finally: finally:
rank_zero_rm(path) rank_zero_rm(path)
@@ -661,6 +663,7 @@ def test_trainer_checkpoint_callback_2(
trainer.load(folder, model_load_fn=model_load_fn) trainer.load(folder, model_load_fn=model_load_fn)


trainer.run() trainer.run()
trainer.driver.barrier()


finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


+ 2
- 2
tests/core/callbacks/test_load_best_model_callback_torch.py View File

@@ -16,7 +16,6 @@ from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback
from fastNLP.core import Evaluator from fastNLP.core import Evaluator
from fastNLP.core.utils.utils import safe_rm
from fastNLP.core.drivers.torch_driver import TorchSingleDriver from fastNLP.core.drivers.torch_driver import TorchSingleDriver
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDataset from tests.helpers.datasets.torch_data import TorchArgMaxDataset
@@ -112,7 +111,8 @@ def test_load_best_model_callback(
results = evaluator.run() results = evaluator.run()
assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1'])
if save_folder: if save_folder:
safe_rm(save_folder)
import shutil
shutil.rmtree(save_folder, ignore_errors=True)
if dist.is_initialized(): if dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()




+ 2
- 1
tests/core/callbacks/test_more_evaluate_callback.py View File

@@ -171,7 +171,7 @@ def test_model_more_evaluate_callback_1(
trainer.load_model(folder, only_state_dict=only_state_dict) trainer.load_model(folder, only_state_dict=only_state_dict)


trainer.run() trainer.run()
trainer.driver.barrier()
finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


@@ -255,6 +255,7 @@ def test_trainer_checkpoint_callback_1(
trainer.load(folder, only_state_dict=only_state_dict) trainer.load(folder, only_state_dict=only_state_dict)


trainer.run() trainer.run()
trainer.driver.barrier()


finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


+ 6
- 3
tests/helpers/utils.py View File

@@ -33,6 +33,8 @@ def recover_logger(fn):
def magic_argv_env_context(fn=None, timeout=600): def magic_argv_env_context(fn=None, timeout=600):
""" """
用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确;
会丢掉 pytest 中的 arg 参数。

:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒;
:return: :return:
""" """
@@ -46,9 +48,10 @@ def magic_argv_env_context(fn=None, timeout=600):
env = deepcopy(os.environ.copy()) env = deepcopy(os.environ.copy())


used_args = [] used_args = []
for each_arg in sys.argv[1:]:
if "test" not in each_arg:
used_args.append(each_arg)
# for each_arg in sys.argv[1:]:
# # warning,否则 可能导致 pytest -s . 中的点混入其中,导致多卡启动的 collect tests items 不为 1
# if each_arg.startswith('-'):
# used_args.append(each_arg)


pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST') pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST')




Loading…
Cancel
Save