
1. torch now takes the GradScaler state into account when saving and loading checkpoints; 2. added GradientClip and Warmup for Torch

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
687db6d86a
10 changed files with 256 additions and 165 deletions
  1. fastNLP/core/callbacks/callback.py (+2 -138)
  2. fastNLP/core/callbacks/checkpoint_callback.py (+4 -3)
  3. fastNLP/core/callbacks/early_stop_callback.py (+1 -1)
  4. fastNLP/core/callbacks/has_monitor_callback.py (+189 -0)
  5. fastNLP/core/callbacks/load_best_model_callback.py (+2 -5)
  6. fastNLP/core/callbacks/progress_callback.py (+1 -1)
  7. fastNLP/core/drivers/torch_driver/initialize_torch_driver.py (+1 -1)
  8. fastNLP/core/drivers/torch_driver/torch_driver.py (+24 -8)
  9. fastNLP/core/utils/utils.py (+31 -7)
  10. tests/envs/test_set_backend.py (+1 -1)

fastNLP/core/callbacks/callback.py (+2 -138)

@@ -1,16 +1,12 @@
from typing import Union, Callable, Dict, Optional, Any
from abc import ABC

__all__ = [
'Callback',
]

from typing import Union, Callable, Dict, Optional, Any

from .callback_events import Events, EventsList, Filter
from .utils import _get_monitor_value
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils.utils import _check_valid_parameters_number


class Callback:
@@ -278,135 +274,3 @@ class _CallbackWrapper(Callback):
@property
def callback_name(self):
return self.fn.__name__


class CanItemDataType(ABC):
"""
检测可以进行传输的对象。

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented


class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
self.set_monitor(monitor, larger_better)
self.must_have_monitor = must_have_monitor

def set_monitor(self, monitor, larger_better):
if callable(monitor): # check that it can accept exactly one argument
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor

def on_after_trainer_initialized(self, trainer, driver):
"""
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。

:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")

def get_monitor_value(self, results: Dict) -> Union[float, None]:
"""
Get the value of the monitor. If the monitor key is not found directly,
try fuzzy matching and store the matched key in self._real_monitor.

:param results:
:return: None means no suitable monitor was found this time.
"""
if len(results) == 0:
return None
# make sure every tensor has been converted to a plain Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if monitor_value is None:
return monitor_value
# first run
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We cannot find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# the matched monitor differs from the one used last time
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results vary between validation runs.")

self._real_monitor = use_monitor
return monitor_value

def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
检测 monitor_value 是否是更好的

:param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better

def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
传入的两个值中,是否monitor_value1的结果更好。

:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better

@property
def monitor_name(self):
"""
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。

:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except AttributeError:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# this must be monitor rather than real_monitor, because real_monitor is reset to monitor when the user runs again
monitor_name = str(self.monitor)
return monitor_name

fastNLP/core/callbacks/checkpoint_callback.py (+4 -3)

@@ -10,9 +10,9 @@ from copy import deepcopy


import fastNLP
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir


@@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback):
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
synchronize_mkdir(folder)
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # only create the folder on rank 0
synchronize_mkdir(folder)
_fn = getattr(trainer, self.save_fn_name)
_fn(
folder=folder,
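The guard added above is the usual rank-0 pattern for filesystem work in distributed runs: only one process creates the folder, and the synchronization helper (see the utils.py changes below) makes the other ranks wait until it becomes visible. A standalone sketch of the guard, assuming the FASTNLP_GLOBAL_RANK constant holds an environment-variable name of the same spelling:

import os
from pathlib import Path

FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK"  # assumed value of the constant

def rank0_mkdir(folder: Path):
    # only the rank-0 process touches the filesystem; other ranks would
    # poll until the directory appears (see wait_filepath in this commit)
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        folder.mkdir(parents=True, exist_ok=True)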


fastNLP/core/callbacks/early_stop_callback.py (+1 -1)

@@ -4,7 +4,7 @@ __all__ = [

from typing import Dict, Union, Callable

from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils.exceptions import EarlyStopException




fastNLP/core/callbacks/has_monitor_callback.py (+189 -0)

@@ -0,0 +1,189 @@
__all__ = [
'HasMonitorCallback',
'ExecuteOnceBetterMonitor'
]

from typing import Dict, Union, Any
from abc import ABC

from fastNLP.core.utils import apply_to_collection
from fastNLP.core.callbacks import Callback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.core.utils.utils import _check_valid_parameters_number


class CanItemDataType(ABC):
"""
检测可以进行传输的对象。

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented



class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: monitor 是否时越大越好
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。
"""
self.set_monitor(monitor, larger_better)
self.must_have_monitor = must_have_monitor

def set_monitor(self, monitor, larger_better):
if callable(monitor): # check that it can accept exactly one argument
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor

def on_after_trainer_initialized(self, trainer, driver):
"""
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。

:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")

def get_monitor_value(self, results: Dict) -> Union[float, None]:
"""
Get the value of the monitor. If the monitor key is not found directly,
try fuzzy matching and store the matched key in self._real_monitor.

:param results:
:return: None means no suitable monitor was found this time.
"""
if len(results) == 0:
return None
# make sure every tensor has been converted to a plain Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if monitor_value is None:
return monitor_value
# first run
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We cannot find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# the matched monitor differs from the one used last time
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results vary between validation runs.")

self._real_monitor = use_monitor
return monitor_value

def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
检测 monitor_value 是否是更好的

:param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better

def is_better_results(self, results, keep_if_better=True):
"""
检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。

:param results: on_valid_ends() 接口中传入的 evaluation 结果。
:param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。
:return:
"""
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return False
return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)

def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
传入的两个值中,是否monitor_value1的结果更好。

:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better

@property
def monitor_name(self):
"""
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。

:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except AttributeError:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# this must be monitor rather than real_monitor, because real_monitor is reset to monitor when the user runs again
monitor_name = str(self.monitor)
return monitor_name


class ExecuteOnceBetterMonitor(HasMonitorCallback):
def __init__(self, monitor, larger_better, execute_fn):
"""
当监控的 monitor 结果更好的时候,调用 execute_fn 函数。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: monitor 是否时越大越好
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。
"""
super().__init__(monitor, larger_better, must_have_monitor=True)
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
self.execute_fn = execute_fn()

def on_validate_end(self, trainer, results):
if self.is_better_results(results):
self.execute_fn()
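A hypothetical usage sketch of the new callback (the Trainer wiring is assumed, not shown in this diff): trigger a side effect only when the monitored metric improves.

def on_better():
    print("new best monitor value reached")

callback = ExecuteOnceBetterMonitor(monitor='acc', larger_better=True,
                                    execute_fn=on_better)
# trainer = Trainer(..., callbacks=[callback])  # argument name assumed
# After each validation, on_validate_end() looks up results['acc'] (or the
# closest-matching key) and calls on_better() only when the value improves.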

fastNLP/core/callbacks/load_best_model_callback.py (+2 -5)

@@ -4,7 +4,7 @@ __all__ = [

import os
from typing import Optional, Callable, Union
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from io import BytesIO
import shutil

@@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback):
self.get_monitor_value(sanity_check_res)

def on_validate_end(self, trainer, results):
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if self.is_better_results(results, keep_if_better=True):
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_save_fn=self.model_save_fn)


fastNLP/core/callbacks/progress_callback.py (+1 -1)

@@ -8,7 +8,7 @@ __all__ = [
'RichCallback'
]

from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger


fastNLP/core/drivers/torch_driver/initialize_torch_driver.py (+1 -1)

@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
# world_size and rank
if FASTNLP_BACKEND_LAUNCH in os.environ:
if device is not None:
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
"up your script. And we will directly get the local device via "
"`os.environ['LOCAL_RANK']`.")
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)


fastNLP/core/drivers/torch_driver/torch_driver.py (+24 -8)

@@ -25,7 +25,7 @@ __all__ = [

from .utils import optimizer_state_to_device
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
@@ -224,6 +224,11 @@ class TorchDriver(Driver):
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;

# 4. save the fp16 (GradScaler) state
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict

logger.debug("Save optimizer state dict")
states["optimizers_state_dict"] = optimizers_state_dict
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
@@ -232,7 +237,7 @@ class TorchDriver(Driver):
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

# 1. load the optimizers' state;
optimizers_state_dict = states["optimizers_state_dict"]
optimizers_state_dict = states.pop("optimizers_state_dict")
for i in range(len(self.optimizers)):
optimizer: torch.optim.Optimizer = self.optimizers[i]
optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
@@ -244,26 +249,37 @@ class TorchDriver(Driver):
res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
if only_state_dict:
model.load_state_dict(res)
logger.debug("Load model state dict.")
logger.debug("Load model state dict...")
else:
model.load_state_dict(res.state_dict())
logger.debug("Load model.")

# 3. restore the sampler state;
logger.debug("Load model...")

# 3. load the fp16 (GradScaler) state
if 'grad_scaler_state_dict' in states:
grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
if not isinstance(self.grad_scaler, DummyGradScaler):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")

# 4. restore the sampler state;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
sampler.load_state_dict(states.pop('sampler_states'))
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. adjust trainer_state.batch_idx_in_epoch
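Outside of fastNLP, the same fp16 bookkeeping with plain PyTorch AMP looks like the following minimal sketch (file name and dict key are illustrative):

import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# saving: GradScaler exposes state_dict() just like modules and optimizers
torch.save({"grad_scaler_state_dict": scaler.state_dict()}, "checkpoint.pt")

# loading: restore the loss-scale and growth counters so a resumed fp16 run
# does not restart from the default scale
states = torch.load("checkpoint.pt")
if "grad_scaler_state_dict" in states:
    scaler.load_state_dict(states.pop("grad_scaler_state_dict"))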


fastNLP/core/utils/utils.py (+31 -7)

@@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
:return:
"""
if fn_name is not None:
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."

parameters = list(inspect.signature(fn).parameters.values())
if inspect.ismethod(fn):
@@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None):
return mask


def wait_to_success(fn, no=False):
def wait_filepath(path, exist=True):
"""
等待当 path 的存在状态为 {exist} 时返回

:param path: 待检测的 path
:param exist: 为 True 时表明检测这个 path 存在就返回; 为 False 表明检测到这个 path 不存在 返回。
:return:
"""
if isinstance(path, str):
path = Path(path)
assert isinstance(path, Path)
count = 0
while True:
sleep(0.01)
if (no and not fn()) or (not no and fn()):
if path.exists() == exist:
break
count += 1
if count % 1000 == 0:
msg = 'created' if exist else 'deleted'
logger.warning(f"Have been waiting for path {path} to be {msg} for {count*0.01} seconds...")



# This exists because distributed file systems can go wrong here: after rank 0 issues the delete it moves on,
# but the actual deletion still has to propagate from rank 0's machine to the remote file system, so rank 1 may
# still observe the file as existing even though rank 0 already considers it deleted;
def synchronize_safe_rm(path: Optional[Union[str, Path]]):
"""
This exists because distributed file systems can go wrong here: after rank 0 issues the delete it moves on,
but the actual deletion still has to propagate from rank 0's machine to the remote file system, so rank 1 may
still observe the file as existing even though rank 0 already considers it deleted.
This function guarantees that every process has observed the deletion of path before returning. Make sure
path is exactly the same on every process, otherwise the processes will deadlock.

:param path:
:return:
"""
if path is None:
return
if isinstance(path, str):
@@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
return
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
_recursive_rm(path)
wait_to_success(path.exists, no=True)
wait_filepath(path, exist=False)


def _recursive_rm(path: Path):
@@ -643,6 +665,8 @@ def _recursive_rm(path: Path):
def synchronize_mkdir(path: Optional[Union[str, Path]]):
"""
Note that this function is for creating a directory; do not use it to create a file.
This function guarantees that every process has observed the creation of path before returning. Make sure
path is exactly the same on every process, otherwise the processes will deadlock.

"""
if path is None:
return
@@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
path.mkdir(parents=True, exist_ok=True)

wait_to_success(path.exists)
wait_filepath(path, exist=True)


def get_class_that_defined_method(method):
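The polling helper generalizes to any cross-process filesystem barrier. A self-contained sketch of the same logic, independent of fastNLP (names are illustrative):

import time
from pathlib import Path

def wait_for_path(path, exist=True, poll=0.01, warn_every=1000):
    # block until path.exists() == exist, logging progress every ~10 seconds
    path = Path(path)
    count = 0
    while path.exists() != exist:
        time.sleep(poll)
        count += 1
        if count % warn_every == 0:
            state = 'created' if exist else 'deleted'
            print(f"Waiting for {path} to be {state} ({count * poll:.0f}s elapsed)...")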


tests/envs/test_set_backend.py (+1 -1)

@@ -1,6 +1,6 @@
import os

from fastNLP.envs.set_env import dump_fastnlp_backend
from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
from fastNLP.core import synchronize_safe_rm


