| @@ -1,16 +1,12 @@ | |||||
| from typing import Union, Callable, Dict, Optional, Any | |||||
| from abc import ABC | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Callback', | 'Callback', | ||||
| ] | ] | ||||
| from typing import Union, Callable, Dict, Optional, Any | |||||
| from .callback_events import Events, EventsList, Filter | from .callback_events import Events, EventsList, Filter | ||||
| from .utils import _get_monitor_value | |||||
| from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
| from fastNLP.core.log import logger | |||||
| from fastNLP.core.utils import apply_to_collection | |||||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
| class Callback: | class Callback: | ||||
| @@ -278,135 +274,3 @@ class _CallbackWrapper(Callback): | |||||
| @property | @property | ||||
| def callback_name(self): | def callback_name(self): | ||||
| return self.fn.__name__ | return self.fn.__name__ | ||||
| class CanItemDataType(ABC): | |||||
| """ | |||||
| 检测可以进行传输的对象。 | |||||
| """ | |||||
| @classmethod | |||||
| def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
| if cls is CanItemDataType: | |||||
| item = getattr(subclass, 'item', None) | |||||
| return callable(item) | |||||
| return NotImplemented | |||||
| class HasMonitorCallback(Callback): | |||||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
| self.set_monitor(monitor, larger_better) | |||||
| self.must_have_moinitor = must_have_monitor | |||||
| def set_monitor(self, monitor, larger_better): | |||||
| if callable(monitor): # 检查是否能够接受一个参数 | |||||
| _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor') | |||||
| self.monitor = monitor | |||||
| else: | |||||
| self.monitor = str(monitor) if monitor is not None else None | |||||
| self.larger_better = bool(larger_better) | |||||
| if larger_better: | |||||
| self.monitor_value = float('-inf') | |||||
| else: | |||||
| self.monitor_value = float('inf') | |||||
| self._real_monitor = self.monitor | |||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| """ | |||||
| 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||||
| 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||||
| :param trainer: | |||||
| :param driver: | |||||
| :return: | |||||
| """ | |||||
| if self.monitor is None and trainer.monitor is not None: | |||||
| self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
| if self.must_have_moinitor and self.monitor is None: | |||||
| raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||||
| f"You can set it in the initialization or through Trainer.") | |||||
| def get_monitor_value(self, results:Dict)->Union[float, None]: | |||||
| """ | |||||
| 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | |||||
| :param results: | |||||
| :return: 如果为 None ,表明此次没有找到合适的monitor | |||||
| """ | |||||
| if len(results)==0: | |||||
| return None | |||||
| # 保证所有的 tensor 都被转换为了 python 特定的类型 | |||||
| results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
| use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
| real_monitor=self._real_monitor, | |||||
| res=results) | |||||
| if monitor_value is None: | |||||
| return monitor_value | |||||
| # 第一次运行 | |||||
| if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||||
| logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||||
| f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
| # 检测到此次和上次不同。 | |||||
| elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||||
| logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||||
| f"The expected monitor is:`{self.monitor}`, last used monitor is:" | |||||
| f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " | |||||
| f"customized monitor function when the evaluation results are varying between validation.") | |||||
| self._real_monitor = use_monitor | |||||
| return monitor_value | |||||
| def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): | |||||
| """ | |||||
| 检测 monitor_value 是否是更好的 | |||||
| :param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False | |||||
| :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||||
| :return: | |||||
| """ | |||||
| if monitor_value is None: | |||||
| return False | |||||
| better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||||
| if keep_if_better and better: | |||||
| self.monitor_value = monitor_value | |||||
| return better | |||||
| def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | |||||
| """ | |||||
| 传入的两个值中,是否monitor_value1的结果更好。 | |||||
| :param monitor_value1: | |||||
| :param monitor_value2: | |||||
| :return: | |||||
| """ | |||||
| if monitor_value1 is None and monitor_value2 is None: | |||||
| return True | |||||
| if monitor_value1 is None: | |||||
| return False | |||||
| if monitor_value2 is None: | |||||
| return True | |||||
| better = False | |||||
| if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||||
| (not self.larger_better and monitor_value1 < monitor_value2): | |||||
| better = True | |||||
| return better | |||||
| @property | |||||
| def monitor_name(self): | |||||
| """ | |||||
| 返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||||
| :return: | |||||
| """ | |||||
| if callable(self.monitor): | |||||
| try: | |||||
| monitor_name = self.monitor.__qualname__ | |||||
| except: | |||||
| monitor_name = self.monitor.__name__ | |||||
| elif self.monitor is None: | |||||
| return None | |||||
| else: | |||||
| # 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 | |||||
| monitor_name = str(self.monitor) | |||||
| return monitor_name | |||||
| @@ -10,9 +10,9 @@ from copy import deepcopy | |||||
| import fastNLP | import fastNLP | ||||
| from .callback import HasMonitorCallback | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK | |||||
| from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | ||||
| @@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| folder = self.timestamp_path.joinpath(folder_name) | folder = self.timestamp_path.joinpath(folder_name) | ||||
| synchronize_mkdir(folder) | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建 | |||||
| synchronize_mkdir(folder) | |||||
| _fn = getattr(trainer, self.save_fn_name) | _fn = getattr(trainer, self.save_fn_name) | ||||
| _fn( | _fn( | ||||
| folder=folder, | folder=folder, | ||||
| @@ -4,7 +4,7 @@ __all__ = [ | |||||
| from typing import Dict, Union, Callable | from typing import Dict, Union, Callable | ||||
| from .callback import HasMonitorCallback | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from fastNLP.core.utils.exceptions import EarlyStopException | from fastNLP.core.utils.exceptions import EarlyStopException | ||||
| @@ -0,0 +1,189 @@ | |||||
| __all__ = [ | |||||
| 'HasMonitorCallback', | |||||
| 'ExecuteOnceBetterMonitor' | |||||
| ] | |||||
| from typing import Dict, Union, Any | |||||
| from abc import ABC | |||||
| from fastNLP.core.utils import apply_to_collection | |||||
| from fastNLP.core.callbacks import Callback | |||||
| from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
| from fastNLP.core.log import logger | |||||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
| class CanItemDataType(ABC): | |||||
| """ | |||||
| 检测可以进行传输的对象。 | |||||
| """ | |||||
| @classmethod | |||||
| def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
| if cls is CanItemDataType: | |||||
| item = getattr(subclass, 'item', None) | |||||
| return callable(item) | |||||
| return NotImplemented | |||||
| class HasMonitorCallback(Callback): | |||||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
| """ | |||||
| 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||||
| (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| :param larger_better: monitor 是否时越大越好 | |||||
| :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
| """ | |||||
| self.set_monitor(monitor, larger_better) | |||||
| self.must_have_moinitor = must_have_monitor | |||||
| def set_monitor(self, monitor, larger_better): | |||||
| if callable(monitor): # 检查是否能够接受一个参数 | |||||
| _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor') | |||||
| self.monitor = monitor | |||||
| else: | |||||
| self.monitor = str(monitor) if monitor is not None else None | |||||
| self.larger_better = bool(larger_better) | |||||
| if larger_better: | |||||
| self.monitor_value = float('-inf') | |||||
| else: | |||||
| self.monitor_value = float('inf') | |||||
| self._real_monitor = self.monitor | |||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| """ | |||||
| 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||||
| 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||||
| :param trainer: | |||||
| :param driver: | |||||
| :return: | |||||
| """ | |||||
| if self.monitor is None and trainer.monitor is not None: | |||||
| self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
| if self.must_have_moinitor and self.monitor is None: | |||||
| raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||||
| f"You can set it in the initialization or through Trainer.") | |||||
| def get_monitor_value(self, results:Dict)->Union[float, None]: | |||||
| """ | |||||
| 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | |||||
| :param results: | |||||
| :return: 如果为 None ,表明此次没有找到合适的monitor | |||||
| """ | |||||
| if len(results)==0: | |||||
| return None | |||||
| # 保证所有的 tensor 都被转换为了 python 特定的类型 | |||||
| results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
| use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
| real_monitor=self._real_monitor, | |||||
| res=results) | |||||
| if monitor_value is None: | |||||
| return monitor_value | |||||
| # 第一次运行 | |||||
| if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||||
| logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||||
| f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
| # 检测到此次和上次不同。 | |||||
| elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||||
| logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||||
| f"The expected monitor is:`{self.monitor}`, last used monitor is:" | |||||
| f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " | |||||
| f"customized monitor function when the evaluation results are varying between validation.") | |||||
| self._real_monitor = use_monitor | |||||
| return monitor_value | |||||
| def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): | |||||
| """ | |||||
| 检测 monitor_value 是否是更好的 | |||||
| :param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False | |||||
| :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||||
| :return: | |||||
| """ | |||||
| if monitor_value is None: | |||||
| return False | |||||
| better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||||
| if keep_if_better and better: | |||||
| self.monitor_value = monitor_value | |||||
| return better | |||||
| def is_better_results(self, results, keep_if_better=True): | |||||
| """ | |||||
| 检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。 | |||||
| :param results: on_valid_ends() 接口中传入的 evaluation 结果。 | |||||
| :param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。 | |||||
| :return: | |||||
| """ | |||||
| monitor_value = self.get_monitor_value(results) | |||||
| if monitor_value is None: | |||||
| return False | |||||
| return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better) | |||||
| def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | |||||
| """ | |||||
| 传入的两个值中,是否monitor_value1的结果更好。 | |||||
| :param monitor_value1: | |||||
| :param monitor_value2: | |||||
| :return: | |||||
| """ | |||||
| if monitor_value1 is None and monitor_value2 is None: | |||||
| return True | |||||
| if monitor_value1 is None: | |||||
| return False | |||||
| if monitor_value2 is None: | |||||
| return True | |||||
| better = False | |||||
| if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||||
| (not self.larger_better and monitor_value1 < monitor_value2): | |||||
| better = True | |||||
| return better | |||||
| @property | |||||
| def monitor_name(self): | |||||
| """ | |||||
| 返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||||
| :return: | |||||
| """ | |||||
| if callable(self.monitor): | |||||
| try: | |||||
| monitor_name = self.monitor.__qualname__ | |||||
| except: | |||||
| monitor_name = self.monitor.__name__ | |||||
| elif self.monitor is None: | |||||
| return None | |||||
| else: | |||||
| # 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 | |||||
| monitor_name = str(self.monitor) | |||||
| return monitor_name | |||||
| class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||||
| def __init__(self, monitor, larger_better, execute_fn): | |||||
| """ | |||||
| 当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| :param larger_better: monitor 是否时越大越好 | |||||
| :param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | |||||
| """ | |||||
| super().__init__(monitor, larger_better, must_have_monitor=True) | |||||
| _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | |||||
| self.execute_fn = execute_fn() | |||||
| def on_validate_end(self, trainer, results): | |||||
| if self.is_better_results(results): | |||||
| self.execute_fn() | |||||
| @@ -4,7 +4,7 @@ __all__ = [ | |||||
| import os | import os | ||||
| from typing import Optional, Callable, Union | from typing import Optional, Callable, Union | ||||
| from .callback import HasMonitorCallback | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from io import BytesIO | from io import BytesIO | ||||
| import shutil | import shutil | ||||
| @@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| self.get_monitor_value(sanity_check_res) | self.get_monitor_value(sanity_check_res) | ||||
| def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
| monitor_value = self.get_monitor_value(results) | |||||
| if monitor_value is None: | |||||
| return | |||||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
| if self.is_better_results(results, keep_if_better=True): | |||||
| if self.real_save_folder: | if self.real_save_folder: | ||||
| trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
| model_save_fn=self.model_save_fn) | model_save_fn=self.model_save_fn) | ||||
| @@ -8,7 +8,7 @@ __all__ = [ | |||||
| 'RichCallback' | 'RichCallback' | ||||
| ] | ] | ||||
| from .callback import HasMonitorCallback | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from fastNLP.core.callbacks.utils import _get_monitor_value | from fastNLP.core.callbacks.utils import _get_monitor_value | ||||
| from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| @@ -98,6 +98,7 @@ class TorchDataLoader(DataLoader): | |||||
| def __getattr__(self, item): | def __getattr__(self, item): | ||||
| """ | """ | ||||
| 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 | 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 | ||||
| :param item: | :param item: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| @@ -119,6 +120,7 @@ class TorchDataLoader(DataLoader): | |||||
| """ | """ | ||||
| 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 | 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 | ||||
| 当val=None时,意味着给定的field_names都不需要尝试padding | 当val=None时,意味着给定的field_names都不需要尝试padding | ||||
| :param field_names: | :param field_names: | ||||
| :param val: padding值,默认为0 | :param val: padding值,默认为0 | ||||
| :return: | :return: | ||||
| @@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
| # world_size 和 rank | # world_size 和 rank | ||||
| if FASTNLP_BACKEND_LAUNCH in os.environ: | if FASTNLP_BACKEND_LAUNCH in os.environ: | ||||
| if device is not None: | if device is not None: | ||||
| logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
| logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
| "up your script. And we will directly get the local device via " | "up your script. And we will directly get the local device via " | ||||
| "`os.environ['LOCAL_RANK']`.") | "`os.environ['LOCAL_RANK']`.") | ||||
| return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | ||||
| @@ -25,7 +25,7 @@ __all__ = [ | |||||
| from .utils import optimizer_state_to_device | from .utils import optimizer_state_to_device | ||||
| from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
| from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env | |||||
| from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler | |||||
| from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | ||||
| from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
| from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
| @@ -224,6 +224,11 @@ class TorchDriver(Driver): | |||||
| optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | ||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | ||||
| # 4. 保存fp16的状态 | |||||
| if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
| grad_scaler_state_dict = self.grad_scaler.state_dict() | |||||
| states['grad_scaler_state_dict'] = grad_scaler_state_dict | |||||
| logger.debug("Save optimizer state dict") | logger.debug("Save optimizer state dict") | ||||
| states["optimizers_state_dict"] = optimizers_state_dict | states["optimizers_state_dict"] = optimizers_state_dict | ||||
| torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
| @@ -232,7 +237,7 @@ class TorchDriver(Driver): | |||||
| states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
| # 1. 加载 optimizers 的状态; | # 1. 加载 optimizers 的状态; | ||||
| optimizers_state_dict = states["optimizers_state_dict"] | |||||
| optimizers_state_dict = states.pop("optimizers_state_dict") | |||||
| for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | optimizer: torch.optim.Optimizer = self.optimizers[i] | ||||
| optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"]) | optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"]) | ||||
| @@ -244,26 +249,37 @@ class TorchDriver(Driver): | |||||
| res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu') | res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu') | ||||
| if only_state_dict: | if only_state_dict: | ||||
| model.load_state_dict(res) | model.load_state_dict(res) | ||||
| logger.debug("Load model state dict.") | |||||
| logger.debug("Load model state dict...") | |||||
| else: | else: | ||||
| model.load_state_dict(res.state_dict()) | model.load_state_dict(res.state_dict()) | ||||
| logger.debug("Load model.") | |||||
| # 3. 恢复 sampler 的状态; | |||||
| logger.debug("Load model...") | |||||
| # 3. 加载fp16的状态 | |||||
| if 'grad_scaler_state_dict' in states: | |||||
| grad_scaler_state_dict = states.pop('grad_scaler_state_dict') | |||||
| if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
| self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||||
| logger.debug("Load grad_scaler state dict...") | |||||
| elif not isinstance(self.grad_scaler, DummyGradScaler): | |||||
| logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||||
| f"the training process may be unstable.") | |||||
| # 4. 恢复 sampler 的状态; | |||||
| dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
| if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | ||||
| sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | elif isinstance(dataloader_args.sampler, ReproducibleSampler): | ||||
| sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
| elif self.is_distributed(): | elif self.is_distributed(): | ||||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | |||||
| "`ReproducibleSampler`.") | |||||
| else: | else: | ||||
| sampler = RandomBatchSampler( | sampler = RandomBatchSampler( | ||||
| batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
| batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
| drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
| ) | ) | ||||
| sampler.load_state_dict(states['sampler_states']) | |||||
| sampler.load_state_dict(states.pop('sampler_states')) | |||||
| states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | ||||
| # 4. 修改 trainer_state.batch_idx_in_epoch | # 4. 修改 trainer_state.batch_idx_in_epoch | ||||
| @@ -14,11 +14,13 @@ if _NEED_IMPORT_PADDLE: | |||||
| import paddle.distributed as dist | import paddle.distributed as dist | ||||
| from paddle.fluid.dygraph import parallel_helper | from paddle.fluid.dygraph import parallel_helper | ||||
| def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | ||||
| gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | ||||
| dist.all_gather(gathered_result, result, group) | dist.all_gather(gathered_result, result, group) | ||||
| return gathered_result | return gathered_result | ||||
| class PaddleBackend(Backend): | class PaddleBackend(Backend): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -124,4 +126,3 @@ class PaddleBackend(Backend): | |||||
| # TODO 如果在这里处理的话,会不会在别的地方引起bug? | # TODO 如果在这里处理的话,会不会在别的地方引起bug? | ||||
| device = get_device_from_visible(device) | device = get_device_from_visible(device) | ||||
| return paddle_to(tensor, device) | return paddle_to(tensor, device) | ||||
| @@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe | |||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch | import torch | ||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| import torch.nn.functional as F | |||||
| def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | ||||
| @@ -33,7 +32,7 @@ class TorchBackend(Backend): | |||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| if method is None: | if method is None: | ||||
| raise AggregateMethodError(should_have_aggregate_method=True) | raise AggregateMethodError(should_have_aggregate_method=True) | ||||
| tensor = fastnlp_torch_all_gather(tensor) | |||||
| tensor = self.all_gather_object(tensor) | |||||
| if isinstance(tensor[0], torch.Tensor): | if isinstance(tensor[0], torch.Tensor): | ||||
| tensor = torch.stack(tensor) | tensor = torch.stack(tensor) | ||||
| # 第一步, aggregate结果 | # 第一步, aggregate结果 | ||||
| @@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if fn_name is not None: | if fn_name is not None: | ||||
| assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}." | |||||
| assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." | |||||
| parameters = list(inspect.signature(fn).parameters.values()) | parameters = list(inspect.signature(fn).parameters.values()) | ||||
| if inspect.ismethod(fn): | if inspect.ismethod(fn): | ||||
| @@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None): | |||||
| return mask | return mask | ||||
| def wait_to_success(fn, no=False): | |||||
| def wait_filepath(path, exist=True): | |||||
| """ | |||||
| 等待当 path 的存在状态为 {exist} 时返回 | |||||
| :param path: 待检测的 path | |||||
| :param exist: 为 True 时表明检测这个 path 存在就返回; 为 False 表明检测到这个 path 不存在 返回。 | |||||
| :return: | |||||
| """ | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| assert isinstance(path, Path) | |||||
| count = 0 | |||||
| while True: | while True: | ||||
| sleep(0.01) | sleep(0.01) | ||||
| if (no and not fn()) or (not no and fn()): | |||||
| if path.exists() == exist: | |||||
| break | break | ||||
| count += 1 | |||||
| if count % 1000 == 0: | |||||
| msg = 'create' if exist else 'delete' | |||||
| logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") | |||||
| # 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||||
| # 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||||
| def synchronize_safe_rm(path: Optional[Union[str, Path]]): | def synchronize_safe_rm(path: Optional[Union[str, Path]]): | ||||
| """ | |||||
| 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||||
| 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||||
| 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| if path is None: | if path is None: | ||||
| return | return | ||||
| if isinstance(path, str): | if isinstance(path, str): | ||||
| @@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||||
| return | return | ||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
| _recursive_rm(path) | _recursive_rm(path) | ||||
| wait_to_success(path.exists, no=True) | |||||
| wait_filepath(path, exist=False) | |||||
| def _recursive_rm(path: Path): | def _recursive_rm(path: Path): | ||||
| @@ -643,6 +665,8 @@ def _recursive_rm(path: Path): | |||||
| def synchronize_mkdir(path: Optional[Union[str, Path]]): | def synchronize_mkdir(path: Optional[Union[str, Path]]): | ||||
| """ | """ | ||||
| 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | ||||
| 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
| """ | """ | ||||
| if path is None: | if path is None: | ||||
| return | return | ||||
| @@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
| path.mkdir(parents=True, exist_ok=True) | path.mkdir(parents=True, exist_ok=True) | ||||
| wait_to_success(path.exists) | |||||
| wait_filepath(path, exist=True) | |||||
| def get_class_that_defined_method(method): | def get_class_that_defined_method(method): | ||||
| @@ -21,11 +21,12 @@ class TestFdl: | |||||
| ds.set_pad_val("x", val=-1) | ds.set_pad_val("x", val=-1) | ||||
| fdl = TorchDataLoader(ds, batch_size=3) | fdl = TorchDataLoader(ds, batch_size=3) | ||||
| fdl.set_input("x", "y") | fdl.set_input("x", "y") | ||||
| fdl.set_pad_val("x", val=None) | |||||
| for batch in fdl: | for batch in fdl: | ||||
| print(batch) | print(batch) | ||||
| fdl.set_pad_val("x", val=-2) | |||||
| for batch in fdl: | |||||
| print(batch) | |||||
| # fdl.set_pad_val("x", val=-2) | |||||
| # for batch in fdl: | |||||
| # print(batch) | |||||
| def test_add_collator(self): | def test_add_collator(self): | ||||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
| @@ -38,6 +39,7 @@ class TestFdl: | |||||
| fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) | fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) | ||||
| fdl.set_input("x", "y") | fdl.set_input("x", "y") | ||||
| # fdl.set_pad_val("x", val=None) | |||||
| fdl.add_collator(collate_fn) | fdl.add_collator(collate_fn) | ||||
| for batch in fdl: | for batch in fdl: | ||||
| print(batch) | print(batch) | ||||
| @@ -0,0 +1,59 @@ | |||||
| import os | |||||
| import pytest | |||||
| import paddle | |||||
| import paddle.distributed | |||||
| import paddle.distributed.fleet.base.role_maker as role_maker | |||||
| import paddle.distributed.fleet as fleet | |||||
| from fastNLP.core.metrics import Accuracy | |||||
| from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher | |||||
| ############################################################################ | |||||
| # | |||||
| # 测试 单机单卡情况下的Accuracy | |||||
| # | |||||
| ############################################################################ | |||||
| def test_accuracy_single(): | |||||
| pred = paddle.to_tensor([[1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669, | |||||
| 0.46514302], | |||||
| [-0.85775983, -2.18273783, -1.07505429, -1.45561373, 0.40011844, | |||||
| 1.02202022], | |||||
| [-0.39487389, 0.65682763, -0.62424040, 0.53692561, -0.28390560, | |||||
| -0.02559055], | |||||
| [-0.22586937, -0.07676325, -0.95977223, 0.36395910, -0.91758579, | |||||
| -0.83857095], | |||||
| [0.25136873, 2.49652624, 1.06251311, 1.60194016, 1.01451588, | |||||
| 0.08403367], | |||||
| [0.10844281, 1.19017303, -0.11378096, 1.12686944, -0.08654942, | |||||
| 0.48605862], | |||||
| [1.27320433, -1.13902378, 1.47072780, -0.98665696, -0.42589864, | |||||
| 0.64618838], | |||||
| [0.83809763, -0.05356205, 0.03042423, -0.28371972, 0.81611472, | |||||
| -0.45802942], | |||||
| [0.38535264, 0.09721313, 2.27187467, 0.32045507, -0.20711982, | |||||
| -0.13550705], | |||||
| [-0.75228405, -1.34161997, 1.08697927, 0.33218071, -1.19470012, | |||||
| 2.58735061]]) | |||||
| tg = paddle.to_tensor([1, 2, 1, 3, 5, 4, 4, 2, 1, 5]) | |||||
| acc_metric = Accuracy() | |||||
| acc_metric.update(pred, tg) | |||||
| result = acc_metric.get_metric() | |||||
| true_result = {'acc': 0.3} | |||||
| assert true_result == result | |||||
| ############################################################################ | |||||
| # | |||||
| # 测试 单机多卡情况下的Accuracy | |||||
| # | |||||
| ############################################################################ | |||||
| def test_accuracy_ddp(): | |||||
| launcher = FleetLauncher(devices=[0, 1]) | |||||
| launcher.launch() | |||||
| role = role_maker.PaddleCloudRoleMaker(is_collective=True) | |||||
| fleet.init(role) | |||||
| if fleet.is_server(): | |||||
| pass | |||||
| elif fleet.is_worker(): | |||||
| print(os.getenv("PADDLE_TRAINER_ID")) | |||||
| @@ -1,6 +1,6 @@ | |||||
| import os | import os | ||||
| from fastNLP.envs.set_env import dump_fastnlp_backend | |||||
| from fastNLP.envs.set_backend import dump_fastnlp_backend | |||||
| from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
| from fastNLP.core import synchronize_safe_rm | from fastNLP.core import synchronize_safe_rm | ||||