|
|
@@ -8,7 +8,10 @@ r""" |
|
|
|
``Trainer`` 通过在自己内部内置一个 ``Evaluator`` 实例来支持在训练过程中进行验证的功能; |
|
|
|
""" |
|
|
|
|
|
|
|
from typing import Union, List, Optional, Dict, Callable |
|
|
|
from typing import Union, List, Optional, Dict, Callable, BinaryIO |
|
|
|
import os |
|
|
|
from pathlib import Path |
|
|
|
import io |
|
|
|
from dataclasses import is_dataclass |
|
|
|
|
|
|
|
__all__ = [ |
|
|
@@ -25,6 +28,8 @@ from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metri |
|
|
|
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader |
|
|
|
from fastNLP.core.utils.utils import _check_valid_parameters_number |
|
|
|
from fastNLP.core.log import logger |
|
|
|
from fastNLP.envs import FASTNLP_MODEL_FILENAME |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Evaluator: |
|
|
@@ -174,6 +179,44 @@ class Evaluator: |
|
|
|
|
|
|
|
self.driver.barrier() |
|
|
|
|
|
|
|
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = True,
               model_load_fn: Optional[Callable] = None, **kwargs):
    """
    Helper that loads a model through this evaluator's ``driver``.

    ``self.driver.barrier()`` is called once before and once after loading
    (presumably to keep distributed processes in step -- confirm against the driver).
    The trailing barrier is only reached when loading succeeds.

    :param folder: Either a directory (``str`` / ``Path``) or an already opened binary
        stream (e.g. ``io.BytesIO`` or a file opened in ``"rb"`` mode).

        * Directory and ``model_load_fn`` is ``None``: the file named
          ``FASTNLP_MODEL_FILENAME`` inside that directory is handed to
          ``self.driver.load_model``.
        * Directory and ``model_load_fn`` is given: ``folder`` is passed to
          ``model_load_fn`` unchanged.
        * Binary stream: the stream itself is handed to ``self.driver.load_model``.
    :param only_state_dict: Whether the file to read contains only the model weights.
        Forwarded to ``self.driver.load_model``; ignored when ``model_load_fn`` is given.
    :param model_load_fn: Optional callable that receives ``folder`` as its single
        argument; its return value is ignored. Not allowed together with a binary stream.
    :param kwargs: Forwarded unchanged to ``self.driver.load_model``; normally not needed.
    :raises ValueError: ``model_load_fn`` is neither ``None`` nor callable.
    :raises RuntimeError: ``model_load_fn`` is given while ``folder`` is a binary stream.
    :raises FileNotFoundError: re-raised from the loading step; an error is logged first
        when the expected checkpoint file does not exist under ``folder``.

    .. note::

        Call this on an ``Evaluator`` instance after it has been initialised. The
        ``driver`` used for saving and the one used for loading must belong to the same
        training framework, e.g. both ``pytorch`` or both ``paddle``.
    """
    self.driver.barrier()

    # ``typing.BinaryIO`` is only an annotation helper: real file handles such as the
    # result of ``open(path, "rb")`` are not instances of it, so the ``io.IOBase``
    # hierarchy (which also covers ``io.BytesIO``) is what actually detects streams.
    if isinstance(folder, (io.IOBase, BinaryIO)):
        if model_load_fn is not None:
            raise RuntimeError("It is not allowed to specify a `model_load_fn` parameter with `folder` being "
                               "`io.BytesIO` type.")
        self.driver.load_model(folder, only_state_dict, **kwargs)
    else:
        try:
            if model_load_fn is not None:
                if not callable(model_load_fn):
                    raise ValueError("Parameter `model_load_fn` should be `Callable` type when it is not None.")
                model_load_fn(folder)
            else:
                # ``Path(...)`` accepts ``str``, ``Path`` and any other ``os.PathLike``.
                checkpoint = Path(folder).joinpath(FASTNLP_MODEL_FILENAME)
                self.driver.load_model(checkpoint, only_state_dict, **kwargs)
        except FileNotFoundError:
            # Probe the checkpoint path directly instead of listing the directory:
            # listing raises again when ``folder`` itself is missing, which would skip
            # the log line and mask the original exception.
            if not os.path.exists(os.path.join(folder, FASTNLP_MODEL_FILENAME)):
                logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.")
            raise

    self.driver.barrier()
|
|
|
|
|
|
|
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: |
|
|
|
""" |
|
|
|
该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; |
|
|
|