diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 90d2e1b1..f63c6088 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -9,6 +9,8 @@ __all__ = [ from .callback_events import Events from .callback import Callback from fastNLP.core.log import logger +from .progress_callback import ProgressCallback, choose_progress_callback +from fastNLP.envs import rank_zero_call def _transfer(func): @@ -26,6 +28,43 @@ def _transfer(func): return wrapper +def prepare_callbacks(callbacks, progress_bar): + """ + + :param callbacks: + :param progress_bar: + :return: + """ + _callbacks = [] + if callbacks is not None: + if isinstance(callbacks, Callback): + callbacks = [callbacks] + if not isinstance(callbacks, Sequence): + raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.") + callbacks = list(callbacks) + for _callback in callbacks: + if not isinstance(_callback, Callback): + raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") + _callbacks += callbacks + + has_no_progress = False + for _callback in _callbacks: + if isinstance(_callback, ProgressCallback): + has_no_progress = True + if not has_no_progress: + callback = choose_progress_callback(progress_bar) + if callback is not None: + _callbacks.append(callback) + elif progress_bar is not None and progress_bar != 'auto': + logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.") + + if has_no_progress and progress_bar is None: + rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " + "during training.") + + return _callbacks + + class CallbackManager: r""" 用来管理训练过程中的所有的 callback 实例; @@ -45,24 +84,13 @@ class CallbackManager: """ self._need_reproducible_sampler = False - _callbacks = [] - if callbacks is not None: - if isinstance(callbacks, Callback): - callbacks = [callbacks] - if not isinstance(callbacks, Sequence): - raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.") - callbacks = list(callbacks) - for _callback in callbacks: - if not isinstance(_callback, Callback): - raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") - _callbacks += callbacks self.callback_fns = defaultdict(list) # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; self.callback_counter = defaultdict(lambda: 0) - if len(_callbacks): + if len(callbacks): # 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它; - self.class_callbacks = _callbacks + self.class_callbacks = callbacks else: self.class_callbacks: Optional[List[Callback]] = [] diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index bacdea48..335345e0 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -11,8 +11,6 @@ __all__ = [ from .has_monitor_callback import HasMonitorCallback from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger -from fastNLP.core.utils.utils import is_notebook - class ProgressCallback(HasMonitorCallback): diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 307901b1..5223c9d8 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -19,8 +19,8 @@ from .evaluator import Evaluator from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList from fastNLP.core.callbacks.callback import _CallbackWrapper +from fastNLP.core.callbacks.callback_manager import prepare_callbacks from fastNLP.core.callbacks.callback_events import _SingleEventState -from fastNLP.core.callbacks.progress_callback import choose_progress_callback from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext @@ -133,7 +133,7 @@ class Trainer(TrainerEventTrigger): ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, - 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 + 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 @@ -212,17 +212,7 @@ class Trainer(TrainerEventTrigger): self.driver.set_optimizers(optimizers=optimizers) # 根据 progress_bar 参数选择 ProgressBarCallback - progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) - if progress_bar_callback is not None: - if callbacks is None: - callbacks = [] - elif not isinstance(callbacks, Sequence): - callbacks = [callbacks] - - callbacks = list(callbacks) + [progress_bar_callback] - else: - rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " - "during training.") + callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto')) # 初始化 callback manager; self.callback_manager = CallbackManager(callbacks) # 添加所有的函数式 callbacks; diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index ca2a3292..0ae9e801 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -73,7 +73,7 @@ def model_and_optimizers(request): @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("only_state_dict", [True, False]) -@magic_argv_env_context +@magic_argv_env_context(timeout=100) def test_model_checkpoint_callback_1( model_and_optimizers: TrainerParameters, driver, @@ -193,7 +193,7 @@ def test_model_checkpoint_callback_1( trainer.load_model(folder, only_state_dict=only_state_dict) trainer.run() - + trainer.driver.barrier() finally: rank_zero_rm(path) @@ -203,7 +203,7 @@ def test_model_checkpoint_callback_1( @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("only_state_dict", [True]) -@magic_argv_env_context +@magic_argv_env_context(timeout=100) def test_model_checkpoint_callback_2( model_and_optimizers: TrainerParameters, driver, @@ -283,6 +283,7 @@ def test_model_checkpoint_callback_2( trainer.load_model(folder, only_state_dict=only_state_dict) trainer.run() + trainer.driver.barrier() finally: rank_zero_rm(path) @@ -295,7 +296,7 @@ def test_model_checkpoint_callback_2( @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("only_state_dict", [True, False]) -@magic_argv_env_context +@magic_argv_env_context(timeout=100) def test_trainer_checkpoint_callback_1( model_and_optimizers: TrainerParameters, driver, @@ -413,6 +414,7 @@ def test_trainer_checkpoint_callback_1( trainer.load(folder, only_state_dict=only_state_dict) trainer.run() + trainer.driver.barrier() finally: rank_zero_rm(path) @@ -661,6 +663,7 @@ def test_trainer_checkpoint_callback_2( trainer.load(folder, model_load_fn=model_load_fn) trainer.run() + trainer.driver.barrier() finally: rank_zero_rm(path) diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py index 0bc63bd5..f5b67f95 100644 --- a/tests/core/callbacks/test_load_best_model_callback_torch.py +++ b/tests/core/callbacks/test_load_best_model_callback_torch.py @@ -16,7 +16,6 @@ from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback from fastNLP.core import Evaluator -from fastNLP.core.utils.utils import safe_rm from fastNLP.core.drivers.torch_driver import TorchSingleDriver from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset @@ -112,7 +111,8 @@ def test_load_best_model_callback( results = evaluator.run() assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) if save_folder: - safe_rm(save_folder) + import shutil + shutil.rmtree(save_folder, ignore_errors=True) if dist.is_initialized(): dist.destroy_process_group() diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 16ee3e17..115f519a 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -171,7 +171,7 @@ def test_model_more_evaluate_callback_1( trainer.load_model(folder, only_state_dict=only_state_dict) trainer.run() - + trainer.driver.barrier() finally: rank_zero_rm(path) @@ -255,6 +255,7 @@ def test_trainer_checkpoint_callback_1( trainer.load(folder, only_state_dict=only_state_dict) trainer.run() + trainer.driver.barrier() finally: rank_zero_rm(path) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index c0b51a8b..7e02ca0d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -33,6 +33,8 @@ def recover_logger(fn): def magic_argv_env_context(fn=None, timeout=600): """ 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; + 会丢掉 pytest 中的 arg 参数。 + :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; :return: """ @@ -46,9 +48,10 @@ def magic_argv_env_context(fn=None, timeout=600): env = deepcopy(os.environ.copy()) used_args = [] - for each_arg in sys.argv[1:]: - if "test" not in each_arg: - used_args.append(each_arg) + # for each_arg in sys.argv[1:]: + # # warning,否则 可能导致 pytest -s . 中的点混入其中,导致多卡启动的 collect tests items 不为 1 + # if each_arg.startswith('-'): + # used_args.append(each_arg) pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST')