diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index b37eda63..117cb524 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -1,16 +1,12 @@
-from typing import Union, Callable, Dict, Optional, Any
-from abc import ABC
 __all__ = [
     'Callback',
 ]
 
+from typing import Union, Callable, Dict, Optional, Any
+
 from .callback_events import Events, EventsList, Filter
-from .utils import _get_monitor_value
 from fastNLP.core.callbacks.callback_events import _SingleEventState
-from fastNLP.core.log import logger
-from fastNLP.core.utils import apply_to_collection
-from fastNLP.core.utils.utils import _check_valid_parameters_number
 
 
 class Callback:
@@ -278,135 +274,3 @@ class _CallbackWrapper(Callback):
     @property
     def callback_name(self):
         return self.fn.__name__
-
-
-class CanItemDataType(ABC):
-    """
-    检测可以进行传输的对象。
-
-    """
-
-    @classmethod
-    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
-        if cls is CanItemDataType:
-            item = getattr(subclass, 'item', None)
-            return callable(item)
-        return NotImplemented
-
-
-class HasMonitorCallback(Callback):
-    def __init__(self, monitor, larger_better, must_have_monitor=False):
-        self.set_monitor(monitor, larger_better)
-        self.must_have_moinitor = must_have_monitor
-
-    def set_monitor(self, monitor, larger_better):
-        if callable(monitor):  # 检查是否能够接受一个参数
-            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
-            self.monitor = monitor
-        else:
-            self.monitor = str(monitor) if monitor is not None else None
-        self.larger_better = bool(larger_better)
-        if larger_better:
-            self.monitor_value = float('-inf')
-        else:
-            self.monitor_value = float('inf')
-        self._real_monitor = self.monitor
-
-    def on_after_trainer_initialized(self, trainer, driver):
-        """
-        如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
-        同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。
-
-        :param trainer:
-        :param driver:
-        :return:
-        """
-        if self.monitor is None and trainer.monitor is not None:
-            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
-        if self.must_have_moinitor and self.monitor is None:
-            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
-                               f"You can set it in the initialization or through Trainer.")
-
-    def get_monitor_value(self, results:Dict)->Union[float, None]:
-        """
-        获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。
-
-        :param results:
-        :return: 如果为 None ,表明此次没有找到合适的monitor
-        """
-        if len(results)==0:
-            return None
-        # 保证所有的 tensor 都被转换为了 python 特定的类型
-        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
-        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
-                                                        real_monitor=self._real_monitor,
-                                                        res=results)
-        if monitor_value is None:
-            return monitor_value
-        # 第一次运行
-        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
-            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
-                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
-        # 检测到此次和上次不同。
-        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
-            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
-                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
-                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
-                           f"customized monitor function when the evaluation results are varying between validation.")
-
-        self._real_monitor = use_monitor
-        return monitor_value
-
-    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
-        """
-        检测 monitor_value 是否是更好的
-
-        :param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False
-        :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
-        :return:
-        """
-        if monitor_value is None:
-            return False
-        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
-        if keep_if_better and better:
-            self.monitor_value = monitor_value
-        return better
-
-    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
-        """
-        传入的两个值中,是否monitor_value1的结果更好。
-
-        :param monitor_value1:
-        :param monitor_value2:
-        :return:
-        """
-        if monitor_value1 is None and monitor_value2 is None:
-            return True
-        if monitor_value1 is None:
-            return False
-        if monitor_value2 is None:
-            return True
-        better = False
-        if (self.larger_better and monitor_value1 > monitor_value2) or \
-            (not self.larger_better and monitor_value1 < monitor_value2):
-            better = True
-        return better
-
-    @property
-    def monitor_name(self):
-        """
-        返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。
-
-        :return:
-        """
-        if callable(self.monitor):
-            try:
-                monitor_name = self.monitor.__qualname__
-            except:
-                monitor_name = self.monitor.__name__
-        elif self.monitor is None:
-            return None
-        else:
-            # 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了
-            monitor_name = str(self.monitor)
-        return monitor_name
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
index d2d97294..b13632d1 100644
--- a/fastNLP/core/callbacks/checkpoint_callback.py
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -10,9 +10,9 @@ from copy import deepcopy
 
 import fastNLP
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.log import logger
-from fastNLP.envs import FASTNLP_LAUNCH_TIME
+from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
 
 
@@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback):
         :return:
         """
         folder = self.timestamp_path.joinpath(folder_name)
-        synchronize_mkdir(folder)
+        if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:  # create the folder only on rank 0
+            synchronize_mkdir(folder)
         _fn = getattr(trainer, self.save_fn_name)
         _fn(
             folder=folder,
diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py
index c679ad7e..0923eb00 100644
--- a/fastNLP/core/callbacks/early_stop_callback.py
+++ b/fastNLP/core/callbacks/early_stop_callback.py
@@ -4,7 +4,7 @@ __all__ = [
 
 from typing import Dict, Union, Callable
 
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.utils.exceptions import EarlyStopException
 
 
diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py
new file mode 100644
index 00000000..54bd9bb4
--- /dev/null
+++ b/fastNLP/core/callbacks/has_monitor_callback.py
@@ -0,0 +1,189 @@
+__all__ = [
+    'HasMonitorCallback',
+    'ExecuteOnceBetterMonitor'
+]
+
+from typing import Dict, Union, Any
+from abc import ABC
+
+from fastNLP.core.utils import apply_to_collection
+from fastNLP.core.callbacks import Callback
+from fastNLP.core.callbacks.utils import _get_monitor_value
+from fastNLP.core.log import logger
+from fastNLP.core.utils.utils import _check_valid_parameters_number
+
+
+class CanItemDataType(ABC):
+    """
+    Duck-type check for objects that can be converted to plain Python values, i.e. objects exposing an ``item()``
+    method.
+    """
+
+    @classmethod
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+        if cls is CanItemDataType:
+            item = getattr(subclass, 'item', None)
+            return callable(item)
+        return NotImplemented
+
+
+class HasMonitorCallback(Callback):
+    def __init__(self, monitor, larger_better, must_have_monitor=False):
+        """
+        This callback is not meant to be used directly; it serves as the base class for other monitor-related
+        callbacks. A callback that uses a monitor can subclass it to get (1) validation of the monitor and
+        (2) falling back to the Trainer's monitor when needed.
+
+        :param monitor: the metric to monitor. If no exactly matching key is found in the evaluation results, a
+            common-substring heuristic picks the closest key as the monitor. If None, the monitor set on the Trainer
+            is used. A function may also be passed; it receives the evaluation results (a dict) and returns a float
+            to be used as the monitor value.
+        :param larger_better: whether a larger monitor value is better.
+        :param must_have_monitor: whether this callback requires a monitor. If True and no monitor is detected, an
+            error is raised.
+        """
+        self.set_monitor(monitor, larger_better)
+        self.must_have_monitor = must_have_monitor
+
+    def set_monitor(self, monitor, larger_better):
+        if callable(monitor):  # make sure the function accepts exactly one argument
+            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
+            self.monitor = monitor
+        else:
+            self.monitor = str(monitor) if monitor is not None else None
+        self.larger_better = bool(larger_better)
+        if larger_better:
+            self.monitor_value = float('-inf')
+        else:
+            self.monitor_value = float('inf')
+        self._real_monitor = self.monitor
+
+    def on_after_trainer_initialized(self, trainer, driver):
+        """
+        If no monitor has been set on this callback, fall back to the monitor set on the Trainer.
+        For callbacks that must have a monitor, this method also performs that check.
+
+        :param trainer:
+        :param driver:
+        :return:
+        """
+        if self.monitor is None and trainer.monitor is not None:
+            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
+        if self.must_have_monitor and self.monitor is None:
+            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
+                               f"You can set it in the initialization or through Trainer.")
+
+    def get_monitor_value(self, results: Dict) -> Union[float, None]:
+        """
+        Get the monitor value. If the monitor key is not found directly, a fuzzy match is attempted and the matched
+        key is stored in ``self._real_monitor``.
+
+        :param results:
+        :return: None means no suitable monitor was found this time.
+        """
+        if len(results) == 0:
+            return None
+        # make sure all tensors have been converted to plain Python values
+        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
+        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
+                                                        real_monitor=self._real_monitor,
+                                                        res=results)
+        if monitor_value is None:
+            return monitor_value
+        # first run
+        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
+            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
+                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
+        # the matched monitor differs from the one used last time
+        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
+            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
+                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
+                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
+                           f"customized monitor function when the evaluation results are varying between validation.")
+
+        self._real_monitor = use_monitor
+        return monitor_value
+
+    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
+        """
+        Check whether ``monitor_value`` is better than the best value seen so far.
+
+        :param monitor_value: the value to check; if None, return False.
+        :param keep_if_better: if the given monitor_value is better, keep it as the new best value.
+        :return:
+        """
+        if monitor_value is None:
+            return False
+        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
+        if keep_if_better and better:
+            self.monitor_value = monitor_value
+        return better
+
+    def is_better_results(self, results, keep_if_better=True):
+        """
+        Check whether the given results are better than the previous best; return False if no matching monitor is
+        found in these results.
+
+        :param results: the evaluation results passed to on_validate_end().
+        :param keep_if_better: when the new value is better, whether to store it in ``self.monitor_value``.
+        :return:
+        """
+        monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return False
+        return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)
+
+    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
+        """
+        Whether ``monitor_value1`` is better than ``monitor_value2``.
+
+        :param monitor_value1:
+        :param monitor_value2:
+        :return:
+        """
+        if monitor_value1 is None and monitor_value2 is None:
+            return True
+        if monitor_value1 is None:
+            return False
+        if monitor_value2 is None:
+            return True
+        better = False
+        if (self.larger_better and monitor_value1 > monitor_value2) or \
+                (not self.larger_better and monitor_value1 < monitor_value2):
+            better = True
+        return better
+
+    @property
+    def monitor_name(self):
+        """
+        Return the name of the monitor; if the monitor is a callable, return that function's name.
+
+        :return:
+        """
+        if callable(self.monitor):
+            try:
+                monitor_name = self.monitor.__qualname__
+            except:
+                monitor_name = self.monitor.__name__
+        elif self.monitor is None:
+            return None
+        else:
+            # this must be monitor rather than real_monitor, because real_monitor is reset to monitor on a new run
+            monitor_name = str(self.monitor)
+        return monitor_name
+
+
+class ExecuteOnceBetterMonitor(HasMonitorCallback):
+    def __init__(self, monitor, larger_better, execute_fn):
+        """
+        Call ``execute_fn`` whenever the monitored value improves.
+
+        :param monitor: the metric to monitor. If no exactly matching key is found in the evaluation results, a
+            common-substring heuristic picks the closest key as the monitor. If None, the monitor set on the Trainer
+            is used. A function may also be passed; it receives the evaluation results (a dict) and returns a float
+            to be used as the monitor value.
+        :param larger_better: whether a larger monitor value is better.
+        :param execute_fn: a callable taking no arguments and returning nothing; it is called whenever the monitor
+            reaches a better value.
+        """
+        super().__init__(monitor, larger_better, must_have_monitor=True)
+        _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
+        self.execute_fn = execute_fn
+
+    def on_validate_end(self, trainer, results):
+        if self.is_better_results(results):
+            self.execute_fn()
\ No newline at end of file
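A minimal usage sketch of the new `ExecuteOnceBetterMonitor` (and the `is_better_results` helper it relies on); this is illustration only and not part of the patch. It drives the callback by hand rather than through a `Trainer`, and the result key `'acc'` and the `remember` hook are placeholders:

    from fastNLP.core.callbacks.has_monitor_callback import ExecuteOnceBetterMonitor

    improvements = []

    def remember():
        # hypothetical hook, called every time the monitored value improves
        improvements.append('improved')

    cb = ExecuteOnceBetterMonitor(monitor='acc', larger_better=True, execute_fn=remember)

    # `results` mimics what the evaluation loop would pass to on_validate_end()
    cb.on_validate_end(trainer=None, results={'acc': 0.80})  # better than -inf -> remember() fires
    cb.on_validate_end(trainer=None, results={'acc': 0.75})  # worse -> nothing happens
    cb.on_validate_end(trainer=None, results={'acc': 0.90})  # new best -> remember() fires again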
diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index 09f85d01..f240caa7 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -4,7 +4,7 @@ __all__ = [
 
 import os
 from typing import Optional, Callable, Union
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from io import BytesIO
 import shutil
 
@@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback):
             self.get_monitor_value(sanity_check_res)
 
     def on_validate_end(self, trainer, results):
-        monitor_value = self.get_monitor_value(results)
-        if monitor_value is None:
-            return
-        if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+        if self.is_better_results(results, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                    model_save_fn=self.model_save_fn)
diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
index f351f204..bb638122 100644
--- a/fastNLP/core/callbacks/progress_callback.py
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -8,7 +8,7 @@ __all__ = [
     'RichCallback'
 ]
 
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.callbacks.utils import _get_monitor_value
 from fastNLP.core.utils import f_rich_progress
 from fastNLP.core.log import logger
diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
index 2c9c5162..f149855f 100644
--- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
     # world_size 和 rank
     if FASTNLP_BACKEND_LAUNCH in os.environ:
         if device is not None:
-            logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
+            logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
                         "up your script. And we will directly get the local device via "
                         "`os.environ['LOCAL_RANK']`.")
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 233d7040..f00d3f1f 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -25,7 +25,7 @@ __all__ = [
 
 from .utils import optimizer_state_to_device
 from fastNLP.core.drivers.driver import Driver
-from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env
+from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
 from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
@@ -224,6 +224,11 @@ class TorchDriver(Driver):
             optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
             optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # 注意这里没有使用 deepcopy,测试是不需要的;
 
+        # 4. save the fp16 (grad scaler) state
+        if not isinstance(self.grad_scaler, DummyGradScaler):
+            grad_scaler_state_dict = self.grad_scaler.state_dict()
+            states['grad_scaler_state_dict'] = grad_scaler_state_dict
+
         logger.debug("Save optimizer state dict")
         states["optimizers_state_dict"] = optimizers_state_dict
         torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
@@ -232,7 +237,7 @@
         states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
 
         # 1. 加载 optimizers 的状态;
-        optimizers_state_dict = states["optimizers_state_dict"]
+        optimizers_state_dict = states.pop("optimizers_state_dict")
         for i in range(len(self.optimizers)):
             optimizer: torch.optim.Optimizer = self.optimizers[i]
             optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
@@ -244,26 +249,37 @@
             res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
             if only_state_dict:
                 model.load_state_dict(res)
-                logger.debug("Load model state dict.")
+                logger.debug("Load model state dict...")
             else:
                 model.load_state_dict(res.state_dict())
-                logger.debug("Load model.")
-
-        # 3. 恢复 sampler 的状态;
+                logger.debug("Load model...")
+
+        # 3. load the fp16 (grad scaler) state
+        if 'grad_scaler_state_dict' in states:
+            grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
+            if not isinstance(self.grad_scaler, DummyGradScaler):
+                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+                logger.debug("Load grad_scaler state dict...")
+        elif not isinstance(self.grad_scaler, DummyGradScaler):
+            logger.warning(f"Checkpoint {folder} was not trained with fp16=True, but you are resuming a training "
+                           f"with fp16=True; the training process may be unstable.")
+
+        # 4. 恢复 sampler 的状态;
         dataloader_args = self.get_dataloader_args(dataloader)
         if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
         elif isinstance(dataloader_args.sampler, ReproducibleSampler):
             sampler = dataloader_args.sampler
         elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
+                               "`ReproducibleBatchSampler` or `ReproducibleSampler`.")
         else:
             sampler = RandomBatchSampler(
                 batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
             )
-        sampler.load_state_dict(states['sampler_states'])
+        sampler.load_state_dict(states.pop('sampler_states'))
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
 
         # 4. 修改 trainer_state.batch_idx_in_epoch
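A short sketch of how the fp16 state added to `TorchDriver.save()`/`load()` travels through a checkpoint; illustration only, not part of the patch. The file name is a placeholder, the key names come from the diff, and a bare `GradScaler` stands in for the driver's `grad_scaler`:

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler()  # stand-in for driver.grad_scaler when fp16 is enabled

    # roughly what TorchDriver.save() now writes
    states = {
        'optimizers_state_dict': {'optimizer0': {}},   # simplified placeholder
        'grad_scaler_state_dict': scaler.state_dict(),
    }
    torch.save(states, 'checkpoint.pkl.tar')           # placeholder file name

    # roughly what TorchDriver.load() now does on resume
    states = torch.load('checkpoint.pkl.tar')
    if 'grad_scaler_state_dict' in states:
        scaler.load_state_dict(states.pop('grad_scaler_state_dict'))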
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
index 729ca960..e0d94cc8 100644
--- a/fastNLP/core/utils/utils.py
+++ b/fastNLP/core/utils/utils.py
@@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
     :return:
     """
     if fn_name is not None:
-        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
+        assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."
 
     parameters = list(inspect.signature(fn).parameters.values())
     if inspect.ismethod(fn):
@@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None):
     return mask
 
 
-def wait_to_success(fn, no=False):
+def wait_filepath(path, exist=True):
+    """
+    Wait until the existence state of ``path`` equals ``exist``, then return.
+
+    :param path: the path to watch.
+    :param exist: if True, return once the path exists; if False, return once the path no longer exists.
+    :return:
+    """
+    if isinstance(path, str):
+        path = Path(path)
+    assert isinstance(path, Path)
+    count = 0
     while True:
         sleep(0.01)
-        if (no and not fn()) or (not no and fn()):
+        if path.exists() == exist:
             break
+        count += 1
+        if count % 1000 == 0:
+            msg = 'create' if exist else 'delete'
+            logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...")
+
 
-# 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
-# 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
 def synchronize_safe_rm(path: Optional[Union[str, Path]]):
+    """
+    On a distributed file system the removal can race: rank 0 issues the delete and moves on, but the delete still
+    has to propagate to the remote file system, so rank 0 already sees the file as gone while rank 1 still sees it.
+    This function therefore only returns after every process has observed that ``path`` has been deleted; make sure
+    ``path`` is exactly the same on every process, otherwise the processes will deadlock.
+
+    :param path:
+    :return:
+    """
     if path is None:
         return
     if isinstance(path, str):
@@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
         return
     if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
         _recursive_rm(path)
-    wait_to_success(path.exists, no=True)
+    wait_filepath(path, exist=False)
 
 
 def _recursive_rm(path: Path):
@@ -643,6 +665,8 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
     """
     注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数;
+    This function only returns after every process has observed that ``path`` has been created; make sure ``path`` is exactly the same on every process, otherwise the processes will deadlock.
+
     """
     if path is None:
         return
@@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
 
     if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
         path.mkdir(parents=True, exist_ok=True)
-    wait_to_success(path.exists)
+    wait_filepath(path, exist=True)
 
 
 def get_class_that_defined_method(method):
diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py
index 2c8fbadf..03931bdc 100644
--- a/tests/envs/test_set_backend.py
+++ b/tests/envs/test_set_backend.py
@@ -1,6 +1,6 @@
 import os
 
-from fastNLP.envs.set_env import dump_fastnlp_backend
+from fastNLP.envs.set_backend import dump_fastnlp_backend
 from tests.helpers.utils import Capturing
 from fastNLP.core import synchronize_safe_rm
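A sketch of the rank-0-acts / all-ranks-wait pattern that `synchronize_mkdir()` and `synchronize_safe_rm()` implement on top of the new `wait_filepath()`; illustration only, not part of the patch, and the directory path is a placeholder that must be identical on every process:

    import os
    import shutil
    from pathlib import Path

    from fastNLP.envs import FASTNLP_GLOBAL_RANK
    from fastNLP.core.utils.utils import wait_filepath

    ckpt_dir = Path("outputs/epoch_3")  # placeholder; must be the same on every rank

    # only rank 0 touches the file system; every rank blocks until the change is visible
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
    wait_filepath(ckpt_dir, exist=True)

    # ... later, rank 0 removes it and every rank waits for the deletion to propagate
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        shutil.rmtree(ckpt_dir)
    wait_filepath(ckpt_dir, exist=False)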