Browse Source

补充部分文档

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
cb5aa37015
14 changed files with 42 additions and 40 deletions
  1. +4
    -2
      fastNLP/core/callbacks/callback_manager.py
  2. +7
    -20
      fastNLP/core/callbacks/checkpoint_callback.py
  3. +2
    -1
      fastNLP/core/callbacks/load_best_model_callback.py
  4. +2
    -3
      fastNLP/core/callbacks/topk_saver.py
  5. +1
    -1
      fastNLP/core/collators/padders/get_padder.py
  6. +4
    -1
      fastNLP/core/collators/padders/paddle_padder.py
  7. +2
    -2
      fastNLP/core/collators/padders/torch_padder.py
  8. +6
    -2
      fastNLP/core/controllers/evaluator.py
  9. +1
    -1
      fastNLP/core/controllers/trainer.py
  10. +1
    -1
      fastNLP/core/dataloaders/prepare_dataloader.py
  11. +2
    -2
      fastNLP/core/dataset/dataset.py
  12. +2
    -2
      fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
  13. +6
    -0
      fastNLP/core/drivers/torch_driver/torch_driver.py
  14. +2
    -2
      tests/core/collators/test_collator.py

+ 4
- 2
fastNLP/core/callbacks/callback_manager.py View File

@@ -10,8 +10,8 @@ from .callback_event import Event
from .callback import Callback
from fastNLP.core.log import logger
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.envs import rank_zero_call
from fastNLP.core.utils.utils import _get_fun_msg
from ..utils.exceptions import EarlyStopException
from ..utils.utils import _get_fun_msg


def _transfer(func):
@@ -25,6 +25,8 @@ def _transfer(func):
for callback_fn in manager.callback_fns[func.__name__]:
try:
callback_fn(*arg, **kwargs)
except EarlyStopException as e:
raise e
except BaseException as e:
logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
raise e


+ 7
- 20
fastNLP/core/callbacks/checkpoint_callback.py View File

@@ -19,7 +19,7 @@ class CheckpointCallback(Callback):
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model',
save_evaluate_results=True, **kwargs):
"""
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下::
保存 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下::

- folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
@@ -29,8 +29,9 @@ class CheckpointCallback(Callback):
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名

model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 model_save_fn 不为 None,
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
@@ -46,22 +47,14 @@ class CheckpointCallback(Callback):
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。
:param kwargs:
"""
super().__init__()
if folder is None:
logger.warning(
"Parameter `folder` is None, and we will use the current work directory to find and load your model.")
folder = Path.cwd()
folder = Path(folder)
if not folder.exists():
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!")
elif folder.is_file():
raise ValueError("Parameter `folder` should be a directory instead of a file.")

if every_n_epochs is not None:
if not isinstance(every_n_epochs, int) or every_n_epochs < 1:
raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.")
@@ -74,12 +67,6 @@ class CheckpointCallback(Callback):
else:
every_n_batches = sys.maxsize # 使得没有数字可以整除

if topk is not None:
if not isinstance(topk, int):
raise ValueError("Parameter `topk` should be an int.")
else:
topk = 0

if on_exceptions is not None:
if not isinstance(on_exceptions, Sequence):
on_exceptions = [on_exceptions]


+ 2
- 1
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -19,7 +19,8 @@ class LoadBestModelCallback(HasMonitorCallback):
model_load_fn:Optional[Callable] = None,
delete_after_train:bool = True):
"""
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载
最好的模型。

:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结


+ 2
- 3
fastNLP/core/callbacks/topk_saver.py View File

@@ -33,9 +33,8 @@ class Saver:
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。
"""
if folder is None:
logger.rank_zero_warning(
"Parameter `folder` is None, and we will use the current work directory to find and load your model.")
folder = Path.cwd()
folder = Path.cwd().absolute()
logger.info(f"Parameter `folder` is None, and we will use {folder} to save and load your model.")
folder = Path(folder)
if not folder.exists():
folder.mkdir(parents=True, exist_ok=True)


+ 1
- 1
fastNLP/core/collators/padders/get_padder.py View File

@@ -121,7 +121,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
# 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'jittor':
return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:


+ 4
- 1
fastNLP/core/collators/padders/paddle_padder.py View File

@@ -141,7 +141,10 @@ class PaddleTensorPadder(Padder):

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
array = np.full(max_shape, fill_value=pad_val)
if isinstance(batch_field[0], paddle.Tensor):
array = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
else:
array = np.full(max_shape, fill_value=pad_val, dtype=batch_field[0].dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
array[slices] = field


+ 2
- 2
fastNLP/core/collators/padders/torch_padder.py View File

@@ -118,8 +118,8 @@ class TorchTensorPadder(Padder):
batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field]
else:
device = batch_field[0].device
if dtype is None:
dtype = batch_field[0].dtype
if dtype is None:
dtype = batch_field[0].dtype
except AttributeError:
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")


+ 6
- 2
fastNLP/core/controllers/evaluator.py View File

@@ -234,8 +234,7 @@ class Evaluator:
"""
调用所有 metric 的 reset() 方法,清除累积的状态。

Returns:

:return:
"""
self.metrics_wrapper.reset()

@@ -357,6 +356,11 @@ class _MetricsWrapper:
metric.update(res)

def reset(self):
"""
将 Metric 中的状态重新设置。

:return:
"""
for metric in self._metrics:
if _is_allennlp_metric(metric):
metric.get_metric(reset=True)


+ 1
- 1
fastNLP/core/controllers/trainer.py View File

@@ -646,7 +646,7 @@ class Trainer(TrainerEventTrigger):
self.driver.save_model(folder, only_state_dict, **kwargs)
self.driver.barrier()

def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False,
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = True,
model_load_fn: Optional[Callable] = None, **kwargs):
"""
加载模型


+ 1
- 1
fastNLP/core/dataloaders/prepare_dataloader.py View File

@@ -10,7 +10,7 @@ from ..samplers import RandomBatchSampler, RandomSampler
from .torch_dataloader import prepare_torch_dataloader
from .paddle_dataloader import prepare_paddle_dataloader
from .jittor_dataloader import prepare_jittor_dataloader
from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS, _module_available
from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
from ..log import logger




+ 2
- 2
fastNLP/core/dataset/dataset.py View File

@@ -451,8 +451,8 @@ class DataSet:
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
show_progress_bar=show_progress_bar, _apply_field=field_name)
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。
if not isinstance(apply_out[0], dict):
raise Exception("The result of func is not a dict")
if not isinstance(apply_out[0], Mapping):
raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}")

for key, value in apply_out[0].items():
results[key] = [value]


+ 2
- 2
fastNLP/core/drivers/torch_driver/initialize_torch_driver.py View File

@@ -55,8 +55,8 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
elif each < 0:
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.")
elif each >= _could_use_device_num:
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
" the available gpu number.")
raise ValueError(f"When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
f" the available gpu number:{_could_use_device_num}.")
device = [torch.device(f"cuda:{w}") for w in device]
elif device is not None and not isinstance(device, torch.device):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")


+ 6
- 0
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -167,6 +167,12 @@ class TorchDriver(Driver):
"""
model = self.unwrap_model()
res = torch.load(filepath, map_location='cpu')
if isinstance(res, dict) and only_state_dict is False:
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
f"`only_state_dict=True`")
elif not isinstance(res, dict) and only_state_dict is True:
logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
f"`only_state_dict=False`")
if only_state_dict:
model.load_state_dict(res)
else:


+ 2
- 2
tests/core/collators/test_collator.py View File

@@ -334,9 +334,9 @@ def test_torch_dl():
dl = TorchDataLoader(ds, batch_size=2)
batch = next(iter(dl))
assert 'x' in batch and 'y' in batch and 'z' in batch and 'i' in batch and 'j' in batch
assert isinstance(batch['z'], torch.Tensor)
assert isinstance(batch['z'], torch.FloatTensor)
assert isinstance(batch['j'], list)
assert isinstance(batch['i']['j'], torch.Tensor)
assert isinstance(batch['i']['j'], torch.LongTensor)

dl.set_ignore('x')
batch = next(iter(dl))


Loading…
Cancel
Save