@@ -9,6 +9,8 @@ __all__ = [ | |||||
from .callback_events import Events | from .callback_events import Events | ||||
from .callback import Callback | from .callback import Callback | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .progress_callback import ProgressCallback, choose_progress_callback | |||||
from fastNLP.envs import rank_zero_call | |||||
def _transfer(func): | def _transfer(func): | ||||
@@ -26,6 +28,43 @@ def _transfer(func): | |||||
return wrapper | return wrapper | ||||
def prepare_callbacks(callbacks, progress_bar):
    """
    Validate the user-supplied callbacks and attach a progress callback when none is present.

    :param callbacks: ``None``, a single ``Callback`` instance, or a sequence of ``Callback``
        instances.
    :param progress_bar: value handed to ``choose_progress_callback`` when ``callbacks`` holds no
        ``ProgressCallback``. If ``callbacks`` already holds one, this value is ignored, and a
        warning is logged when it is neither ``None`` nor ``'auto'``.
    :return: a new list with the validated callbacks, followed by the chosen progress callback
        (if one was chosen).
    :raises ValueError: if ``callbacks`` is neither ``None``, a ``Callback`` nor a ``Sequence``.
    :raises TypeError: if any element of ``callbacks`` is not a ``Callback``.
    """
    _callbacks = []
    if callbacks is not None:
        # Accept a bare Callback for convenience by wrapping it into a list.
        if isinstance(callbacks, Callback):
            callbacks = [callbacks]
        if not isinstance(callbacks, Sequence):
            raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.")
        callbacks = list(callbacks)
        for _callback in callbacks:
            if not isinstance(_callback, Callback):
                raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
        _callbacks += callbacks

    # True when the user already passed a ProgressCallback among `callbacks`.
    has_progress = any(isinstance(_callback, ProgressCallback) for _callback in _callbacks)

    if not has_progress:
        # No explicit progress callback: derive one from `progress_bar`.
        callback = choose_progress_callback(progress_bar)
        if callback is not None:
            _callbacks.append(callback)
    elif progress_bar is not None and progress_bar != 'auto':
        # An explicit ProgressCallback takes precedence over the `progress_bar` argument.
        logger.warning("Since you have passed in ProgressBar callback, progress_bar will be ignored.")

    # Warn only when there really is no progress reporting: no ProgressCallback was supplied and
    # `progress_bar` is None. (Previously the flag was inverted, so the warning fired when a
    # ProgressCallback *was* supplied and never fired in the case it describes.)
    if not has_progress and progress_bar is None:
        rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
                                       "during training.")

    return _callbacks
class CallbackManager: | class CallbackManager: | ||||
r""" | r""" | ||||
用来管理训练过程中的所有的 callback 实例; | 用来管理训练过程中的所有的 callback 实例; | ||||
@@ -45,24 +84,13 @@ class CallbackManager: | |||||
""" | """ | ||||
self._need_reproducible_sampler = False | self._need_reproducible_sampler = False | ||||
_callbacks = [] | |||||
if callbacks is not None: | |||||
if isinstance(callbacks, Callback): | |||||
callbacks = [callbacks] | |||||
if not isinstance(callbacks, Sequence): | |||||
raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.") | |||||
callbacks = list(callbacks) | |||||
for _callback in callbacks: | |||||
if not isinstance(_callback, Callback): | |||||
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | |||||
_callbacks += callbacks | |||||
self.callback_fns = defaultdict(list) | self.callback_fns = defaultdict(list) | ||||
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | ||||
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | ||||
self.callback_counter = defaultdict(lambda: 0) | self.callback_counter = defaultdict(lambda: 0) | ||||
if len(_callbacks): | |||||
if len(callbacks): | |||||
# 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它; | # 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它; | ||||
self.class_callbacks = _callbacks | |||||
self.class_callbacks = callbacks | |||||
else: | else: | ||||
self.class_callbacks: Optional[List[Callback]] = [] | self.class_callbacks: Optional[List[Callback]] = [] | ||||
@@ -11,8 +11,6 @@ __all__ = [ | |||||
from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils.utils import is_notebook | |||||
class ProgressCallback(HasMonitorCallback): | class ProgressCallback(HasMonitorCallback): | ||||
@@ -19,8 +19,8 @@ from .evaluator import Evaluator | |||||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | ||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | ||||
from fastNLP.core.callbacks.callback import _CallbackWrapper | from fastNLP.core.callbacks.callback import _CallbackWrapper | ||||
from fastNLP.core.callbacks.callback_manager import prepare_callbacks | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.callbacks.progress_callback import choose_progress_callback | |||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | ||||
@@ -133,7 +133,7 @@ class Trainer(TrainerEventTrigger): | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | ||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | ||||
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 | train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 | ||||
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | ||||
@@ -212,17 +212,7 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.set_optimizers(optimizers=optimizers) | self.driver.set_optimizers(optimizers=optimizers) | ||||
# 根据 progress_bar 参数选择 ProgressBarCallback | # 根据 progress_bar 参数选择 ProgressBarCallback | ||||
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) | |||||
if progress_bar_callback is not None: | |||||
if callbacks is None: | |||||
callbacks = [] | |||||
elif not isinstance(callbacks, Sequence): | |||||
callbacks = [callbacks] | |||||
callbacks = list(callbacks) + [progress_bar_callback] | |||||
else: | |||||
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " | |||||
"during training.") | |||||
callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto')) | |||||
# 初始化 callback manager; | # 初始化 callback manager; | ||||
self.callback_manager = CallbackManager(callbacks) | self.callback_manager = CallbackManager(callbacks) | ||||
# 添加所有的函数式 callbacks; | # 添加所有的函数式 callbacks; | ||||
@@ -73,7 +73,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | ||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | |||||
@magic_argv_env_context(timeout=100) | |||||
def test_model_checkpoint_callback_1( | def test_model_checkpoint_callback_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
@@ -193,7 +193,7 @@ def test_model_checkpoint_callback_1( | |||||
trainer.load_model(folder, only_state_dict=only_state_dict) | trainer.load_model(folder, only_state_dict=only_state_dict) | ||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -203,7 +203,7 @@ def test_model_checkpoint_callback_1( | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | ||||
@pytest.mark.parametrize("only_state_dict", [True]) | @pytest.mark.parametrize("only_state_dict", [True]) | ||||
@magic_argv_env_context | |||||
@magic_argv_env_context(timeout=100) | |||||
def test_model_checkpoint_callback_2( | def test_model_checkpoint_callback_2( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
@@ -283,6 +283,7 @@ def test_model_checkpoint_callback_2( | |||||
trainer.load_model(folder, only_state_dict=only_state_dict) | trainer.load_model(folder, only_state_dict=only_state_dict) | ||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -295,7 +296,7 @@ def test_model_checkpoint_callback_2( | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | ||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | |||||
@magic_argv_env_context(timeout=100) | |||||
def test_trainer_checkpoint_callback_1( | def test_trainer_checkpoint_callback_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
@@ -413,6 +414,7 @@ def test_trainer_checkpoint_callback_1( | |||||
trainer.load(folder, only_state_dict=only_state_dict) | trainer.load(folder, only_state_dict=only_state_dict) | ||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -661,6 +663,7 @@ def test_trainer_checkpoint_callback_2( | |||||
trainer.load(folder, model_load_fn=model_load_fn) | trainer.load(folder, model_load_fn=model_load_fn) | ||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -16,7 +16,6 @@ from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | from fastNLP.core.metrics.accuracy import Accuracy | ||||
from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback | from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback | ||||
from fastNLP.core import Evaluator | from fastNLP.core import Evaluator | ||||
from fastNLP.core.utils.utils import safe_rm | |||||
from fastNLP.core.drivers.torch_driver import TorchSingleDriver | from fastNLP.core.drivers.torch_driver import TorchSingleDriver | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
@@ -112,7 +111,8 @@ def test_load_best_model_callback( | |||||
results = evaluator.run() | results = evaluator.run() | ||||
assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | ||||
if save_folder: | if save_folder: | ||||
safe_rm(save_folder) | |||||
import shutil | |||||
shutil.rmtree(save_folder, ignore_errors=True) | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -171,7 +171,7 @@ def test_model_more_evaluate_callback_1( | |||||
trainer.load_model(folder, only_state_dict=only_state_dict) | trainer.load_model(folder, only_state_dict=only_state_dict) | ||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -255,6 +255,7 @@ def test_trainer_checkpoint_callback_1( | |||||
trainer.load(folder, only_state_dict=only_state_dict) | trainer.load(folder, only_state_dict=only_state_dict) | ||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -33,6 +33,8 @@ def recover_logger(fn): | |||||
def magic_argv_env_context(fn=None, timeout=600): | def magic_argv_env_context(fn=None, timeout=600): | ||||
""" | """ | ||||
用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; | 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; | ||||
会丢掉 pytest 中的 arg 参数。 | |||||
:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; | :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -46,9 +48,10 @@ def magic_argv_env_context(fn=None, timeout=600): | |||||
env = deepcopy(os.environ.copy()) | env = deepcopy(os.environ.copy()) | ||||
used_args = [] | used_args = [] | ||||
for each_arg in sys.argv[1:]: | |||||
if "test" not in each_arg: | |||||
used_args.append(each_arg) | |||||
# for each_arg in sys.argv[1:]: | |||||
# # warning,否则 可能导致 pytest -s . 中的点混入其中,导致多卡启动的 collect tests items 不为 1 | |||||
# if each_arg.startswith('-'): | |||||
# used_args.append(each_arg) | |||||
pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST') | pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST') | ||||