From 653b129e608380d69c29cfcb92dc2c527754cacb Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 22 May 2022 16:29:45 +0800 Subject: [PATCH] =?UTF-8?q?Evaluator=E6=96=B0=E5=A2=9Eload=5Fmodel?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 1 + fastNLP/core/controllers/evaluator.py | 45 ++++++++++++++++++- .../controllers/loops/evaluate_batch_loop.py | 6 ++- .../core/dataloaders/torch_dataloader/fdl.py | 3 +- fastNLP/core/dataset/dataset.py | 6 +-- fastNLP/transformers/torch/modeling_utils.py | 4 +- fastNLP/transformers/torch/models/__init__.py | 1 + .../test_checkpoint_callback_torch.py | 8 ++++ 8 files changed, 66 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 095c314c..855c335f 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -93,6 +93,7 @@ __all__ = [ "f_rich_progress", "auto_param_call", "seq_len_to_mask", + "f_tqdm_progress", # vocabulary.py 'Vocabulary' diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 0f7476e1..3a584b4b 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -8,7 +8,10 @@ r""" ``Trainer`` 通过来自己内部内置一个 ``Evaluator`` 实例来支持在训练过程中进行验证的功能; """ -from typing import Union, List, Optional, Dict, Callable +from typing import Union, List, Optional, Dict, Callable, BinaryIO +import os +from pathlib import Path +import io from dataclasses import is_dataclass __all__ = [ @@ -25,6 +28,8 @@ from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metri from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader from fastNLP.core.utils.utils import _check_valid_parameters_number from fastNLP.core.log import logger +from fastNLP.envs import FASTNLP_MODEL_FILENAME + class Evaluator: @@ -174,6 +179,44 @@ class Evaluator: self.driver.barrier() + def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = True, + model_load_fn: Optional[Callable] = None, **kwargs): + """ + 用于帮助您加载模型的辅助函数; + + :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, + 直接将该 folder 传递到 model_load_fn 中; + :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; + :param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; + :param kwargs: 理论上您不需要使用到该参数; + + .. note:: + + 注意您需要在初始化 ``Evaluator`` 后再通过 ``evaluator`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个 + 训练框架的,例如都是 ``pytorch`` 或者 ``paddle``; + """ + self.driver.barrier() + if not isinstance(folder, (io.BytesIO, BinaryIO)): + try: + if model_load_fn is not None: + if not callable(model_load_fn): + raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") + model_load_fn(folder) + else: + if isinstance(folder, str): + folder = Path(folder) + self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) + except FileNotFoundError as e: + if FASTNLP_MODEL_FILENAME not in os.listdir(folder): + logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.") + raise e + else: + if model_load_fn is not None: + raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " + "`io.BytesIO` type.") + self.driver.load_model(folder, only_state_dict, **kwargs) + self.driver.barrier() + def run(self, num_eval_batch_per_dl: int = -1) -> Dict: """ 该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py index c6301772..fb936236 100644 --- a/fastNLP/core/controllers/loops/evaluate_batch_loop.py +++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py @@ -45,7 +45,11 @@ class EvaluateBatchLoop(Loop): except BaseException as e: if callable(getattr(dataloader, 'get_batch_indices', None)): indices = dataloader.get_batch_indices() - logger.error(f"Exception happens when evaluating on samples: {indices}") + if evaluator.cur_dataloader_name is not None: + logger.error(f"Exception happens when evaluating on samples in dataloader:" + f"{evaluator.cur_dataloader_name}: {indices}") + else: + logger.error(f"Exception happens when evaluating on samples: {indices}") raise e # 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} results = evaluator.get_metric() diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 4b289a1d..4e208b3f 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -220,7 +220,8 @@ def prepare_torch_dataloader(ds_or_db, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, - persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, + persistent_workers: bool = False, + non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, non_train_batch_size: int = 16) \ -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: """ diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index b808e01b..fac13195 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -119,10 +119,10 @@ class DataSet: self._collator = Collator() if data is not None: if isinstance(data, Dict): - length_set = set() + length_set = {} for key, value in data.items(): - length_set.add(len(value)) - assert len(length_set) == 1, "Arrays must all be same length." + length_set[key] = len(value) + assert len(set(length_set.values())) == 1, f"Fields must all be of same length, instead of {length_set}." for key, value in data.items(): self.add_field(field_name=key, fields=value) elif isinstance(data, List): diff --git a/fastNLP/transformers/torch/modeling_utils.py b/fastNLP/transformers/torch/modeling_utils.py index 74f370b6..6a31649e 100644 --- a/fastNLP/transformers/torch/modeling_utils.py +++ b/fastNLP/transformers/torch/modeling_utils.py @@ -1497,7 +1497,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: - logger.warning( + logger.rank_zero_warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." @@ -1515,7 +1515,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): for key, shape1, shape2 in mismatched_keys ] ) - logger.warning( + logger.rank_zero_warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." diff --git a/fastNLP/transformers/torch/models/__init__.py b/fastNLP/transformers/torch/models/__init__.py index ddf3005f..39191a1e 100644 --- a/fastNLP/transformers/torch/models/__init__.py +++ b/fastNLP/transformers/torch/models/__init__.py @@ -1,3 +1,4 @@ +from .auto import * from .bart import * from .bert import * from .cpt import * diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 1fc2f9ee..d227a162 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -8,6 +8,7 @@ import time from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback from fastNLP.core.controllers.trainer import Trainer +from fastNLP import Evaluator from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context @@ -286,6 +287,13 @@ def test_model_checkpoint_callback_2( trainer.load_model(folder, only_state_dict=only_state_dict) trainer.run() + evaluator = Evaluator(model=model_and_optimizers.model, driver='torch', device=0, + dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics) + evaluator.load_model(folder, only_state_dict=only_state_dict) + evaluator.run() trainer.driver.barrier() finally: