diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index f45cf5e0..a47ab998 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -4,7 +4,8 @@ __all__ = [ 'EventsList', 'Filter', 'CallbackManager', - 'CheckpointCallback', + 'ModelCheckpointCallback', + 'TrainerCheckpointCallback', 'choose_progress_callback', 'ProgressCallback', 'RichCallback', @@ -16,7 +17,7 @@ __all__ = [ from .callback import Callback from .callback_events import EventsList, Events, Filter from .callback_manager import CallbackManager -from .checkpoint_callback import CheckpointCallback +from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 5cd102e0..d3a3b52d 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -1,5 +1,6 @@ __all__ = [ - 'CheckpointCallback' + 'ModelCheckpointCallback', + 'TrainerCheckpointCallback' ] import os from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index f58a7faf..bd66d0a0 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -133,17 +133,18 @@ class Evaluator: self.driver.barrier() - def run(self, num_eval_batch_per_dl: int = -1) -> Dict: + def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: """ 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 - 如果存在多个metric,一个dataloader的情况,key的命名规则是 - metric_indicator_name#metric_name - 如果存在多个数据集,一个metric的情况,key的命名规则是 - metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 - 如果存在多个metric,多个dataloader的情况,key的命名规则是 - metric_indicator_name#metric_name#dataloader_name - :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 + 如果存在多个metric,一个dataloader的情况,key的命名规则是 + metric_indicator_name#metric_name + 如果存在多个数据集,一个metric的情况,key的命名规则是 + metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 + 如果存在多个metric,多个dataloader的情况,key的命名规则是 + metric_indicator_name#metric_name#dataloader_name + 其中 metric_indicator_name 可能不存在。 + :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 :return: """ assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." @@ -157,7 +158,6 @@ class Evaluator: assert self.driver.has_test_dataloaders() metric_results = {} - self.reset() evaluate_context = self.driver.get_evaluate_context() self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e7aaeea8..b7456b61 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") if self.evaluator is not None and num_eval_sanity_batch > 0: + logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") self.on_sanity_check_begin() sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) self.on_sanity_check_end(sanity_check_res) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 037fde00..9630a3a0 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -8,9 +8,8 @@ __all__ = [ import _pickle as pickle from copy import deepcopy -from typing import Optional, List, Callable, Union, Dict, Any +from typing import Optional, List, Callable, Union, Dict, Any, Mapping from functools import partial -import warnings import numpy as np from threading import Thread @@ -197,6 +196,20 @@ class DataSet: else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) + def __setitem__(self, key, value): + assert isinstance(key, int) and key