@@ -4,7 +4,8 @@ __all__ = [ | |||||
'EventsList', | 'EventsList', | ||||
'Filter', | 'Filter', | ||||
'CallbackManager', | 'CallbackManager', | ||||
'CheckpointCallback', | |||||
'ModelCheckpointCallback', | |||||
'TrainerCheckpointCallback', | |||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
@@ -16,7 +17,7 @@ __all__ = [ | |||||
from .callback import Callback | from .callback import Callback | ||||
from .callback_events import EventsList, Events, Filter | from .callback_events import EventsList, Events, Filter | ||||
from .callback_manager import CallbackManager | from .callback_manager import CallbackManager | ||||
from .checkpoint_callback import CheckpointCallback | |||||
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
@@ -8,7 +8,7 @@ __all__ = [ | |||||
from .callback_events import Events | from .callback_events import Events | ||||
from .callback import Callback | from .callback import Callback | ||||
from .checkpoint_callback import CheckpointCallback | |||||
from .checkpoint_callback import TrainerCheckpointCallback | |||||
from .progress_callback import ProgressCallback, choose_progress_callback | from .progress_callback import ProgressCallback, choose_progress_callback | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -98,7 +98,7 @@ class CallbackManager: | |||||
:return: | :return: | ||||
""" | """ | ||||
for each_callback in self.class_callbacks: | for each_callback in self.class_callbacks: | ||||
if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint: | |||||
if isinstance(each_callback, TrainerCheckpointCallback): | |||||
self._has_trainer_checkpoint = True | self._has_trainer_checkpoint = True | ||||
self.dissect_one_callback(each_callback) | self.dissect_one_callback(each_callback) | ||||
@@ -210,7 +210,7 @@ class CallbackManager: | |||||
each_callback.on_load_checkpoint(trainer, None) | each_callback.on_load_checkpoint(trainer, None) | ||||
@property | @property | ||||
def has_trainer_chechpoint(self) -> bool: | |||||
def has_trainer_checkpoint(self) -> bool: | |||||
return self._has_trainer_checkpoint | return self._has_trainer_checkpoint | ||||
@_transfer | @_transfer | ||||
@@ -1,12 +1,13 @@ | |||||
__all__ = [ | |||||
'ModelCheckpointCallback', | |||||
'TrainerCheckpointCallback' | |||||
] | |||||
import os | import os | ||||
from typing import Union, Optional, Callable, Dict, Sequence | |||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||||
from pathlib import Path | from pathlib import Path | ||||
from functools import partial | |||||
from time import sleep | |||||
from abc import ABC | |||||
import sys | |||||
__all__ = [ | |||||
'CheckpointCallback' | |||||
] | |||||
import fastNLP | import fastNLP | ||||
from .callback import Callback, Filter | from .callback import Callback, Filter | ||||
@@ -14,35 +15,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | ||||
from fastNLP.core.utils import apply_to_collection | |||||
class CheckpointCallback(Callback): | |||||
class CanItemDataType(ABC): | |||||
""" | """ | ||||
1. 因为只有 'Trainer' 才有 callback,因此评测 metric 实际上就是 validate 时干的事情; | |||||
2. 默认 'save_last' 为 True,即 model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型,模型名字为 last.pth.tar; | |||||
3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视,如果用户在训练过程中指定了多个 monitor 的监视,例如 "acc1", | |||||
"acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例; | |||||
4. 理论上,在实际保存的过程中,topk 模式和 固定频率保存的模式是完全独立的,我们确实应当采取一些措施至少保证两者的名字不一样; | |||||
检测可以进行传输的对象。 | |||||
""" | """ | ||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is CanItemDataType: | |||||
item = getattr(subclass, 'item', None) | |||||
return callable(item) | |||||
return NotImplemented | |||||
class CheckpointCallback(Callback): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
monitor, | monitor, | ||||
is_trainer_checkpoint: Optional[bool] = False, | |||||
save_folder: Optional[Union[str, Path]] = None, | save_folder: Optional[Union[str, Path]] = None, | ||||
save_every_n_epochs: Optional[int] = None, | save_every_n_epochs: Optional[int] = None, | ||||
save_every_n_global_batches: Optional[int] = None, | |||||
save_every_n_batches: Optional[int] = None, | |||||
save_last: bool = True, | save_last: bool = True, | ||||
save_topk: Optional[int] = None, | save_topk: Optional[int] = None, | ||||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | ||||
larger_better: bool = True, | larger_better: bool = True, | ||||
only_state_dict: bool = True, | only_state_dict: bool = True, | ||||
model_save_fn: Optional[Callable] = None, | model_save_fn: Optional[Callable] = None, | ||||
**kwargs, | **kwargs, | ||||
): | ): | ||||
if monitor is None and save_topk is not None: | if monitor is None and save_topk is not None: | ||||
@@ -51,9 +54,6 @@ class CheckpointCallback(Callback): | |||||
if monitor is not None and not isinstance(monitor, str): | if monitor is not None and not isinstance(monitor, str): | ||||
raise ValueError("Parameter `monitor` should be of 'str' type.") | raise ValueError("Parameter `monitor` should be of 'str' type.") | ||||
if not isinstance(is_trainer_checkpoint, bool): | |||||
raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.") | |||||
if save_folder is None: | if save_folder is None: | ||||
logger.warning( | logger.warning( | ||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | "Parameter `path` is None, and we will use the current work directory to find and load your model.") | ||||
@@ -67,15 +67,15 @@ class CheckpointCallback(Callback): | |||||
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | ||||
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | ||||
# 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的 | |||||
# 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样; | |||||
self._filter_every_n_epochs = Filter(every=save_every_n_epochs) | |||||
else: | |||||
save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||||
if save_every_n_global_batches is not None: | |||||
if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1: | |||||
if save_every_n_batches is not None: | |||||
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: | |||||
raise ValueError( | raise ValueError( | ||||
"parameter save_every_n_global_batches should be an int and greater than or equal to 1.") | |||||
self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches) | |||||
"parameter save_every_n_batches should be an int and greater than or equal to 1.") | |||||
else: | |||||
save_every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||||
if save_topk is not None: | if save_topk is not None: | ||||
if not isinstance(save_topk, int) or save_topk < 1: | if not isinstance(save_topk, int) or save_topk < 1: | ||||
@@ -89,12 +89,12 @@ class CheckpointCallback(Callback): | |||||
if not issubclass(exception, BaseException): | if not issubclass(exception, BaseException): | ||||
raise TypeError("Each exception in parameter `save_on_exception` can only be " | raise TypeError("Each exception in parameter `save_on_exception` can only be " | ||||
"`BaseException` type.") | "`BaseException` type.") | ||||
else: | |||||
save_on_exception = [] | |||||
self.monitor = monitor | self.monitor = monitor | ||||
self.is_trainer_checkpoint = is_trainer_checkpoint | |||||
self.save_folder = Path(save_folder) | self.save_folder = Path(save_folder) | ||||
self.save_every_n_epochs = save_every_n_epochs | self.save_every_n_epochs = save_every_n_epochs | ||||
self.save_every_n_global_batches = save_every_n_global_batches | |||||
self.save_every_n_batches = save_every_n_batches | |||||
self.save_last = save_last | self.save_last = save_last | ||||
self.save_topk = save_topk | self.save_topk = save_topk | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
@@ -107,7 +107,7 @@ class CheckpointCallback(Callback): | |||||
self._topk_model = {} | self._topk_model = {} | ||||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | self._topn = 0 # 表示目前已经保存了几个最好的模型; | ||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个 | |||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 | |||||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | ||||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | ||||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | ||||
@@ -115,76 +115,83 @@ class CheckpointCallback(Callback): | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | ||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | ||||
self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | ||||
synchronize_mkdir(self.log_filepath) | |||||
synchronize_mkdir(self.timestamp_path) | |||||
def on_validate_end(self, trainer, validate_res): | def on_validate_end(self, trainer, validate_res): | ||||
self._save_topk(trainer, validate_res) | self._save_topk(trainer, validate_res) | ||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
self._save_every_n_epochs(trainer) | |||||
self._save_last(trainer) | |||||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | |||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' | |||||
self.save(trainer, folder_name=folder_name) | |||||
if self.save_last: | |||||
folder_name = f'{self.folder_prefix}-last' | |||||
self.save(trainer, folder_name=folder_name) | |||||
def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
self._save_every_n_global_batches(trainer) | |||||
if trainer.global_forward_batches % self.save_every_n_batches == 0: | |||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||||
self.save(trainer, folder_name=folder_name) | |||||
def on_exception(self, trainer, exception: BaseException): | def on_exception(self, trainer, exception: BaseException): | ||||
if self.save_on_exception is not None and exception.__class__ in self.save_on_exception: | |||||
folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None) | |||||
folder = folder + f"_{exception.__class__.__name__}" | |||||
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder) | |||||
if exception.__class__ in self.save_on_exception: | |||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||||
f'exception_{exception.__class__.__name__}' | |||||
self.save(trainer=trainer, folder_name=folder_name) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
# 主要核对一下 monitor 是否存在。 | |||||
self._get_validate_metric(sanity_check_res) | self._get_validate_metric(sanity_check_res) | ||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | """ | ||||
我们需要保存 CheckpointCallback 内部的几个 filter 的状态; | |||||
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | |||||
topk_model的状态 | |||||
_real_monitor的值 | |||||
""" | """ | ||||
states = {} | states = {} | ||||
if self.save_every_n_epochs is not None: | |||||
states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict() | |||||
if self.save_every_n_global_batches is not None: | |||||
states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict() | |||||
states["real_monitor"] = self._real_monitor | |||||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||||
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, | |||||
function=lambda x:x.item()) | |||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||||
states['_real_monitor'] = self._real_monitor | |||||
return states | return states | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
if self.save_every_n_epochs is not None: | |||||
self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"]) | |||||
if self.save_every_n_global_batches is not None: | |||||
self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"]) | |||||
timestamp_path = states['timestamp_path'] | |||||
if not os.path.exists(timestamp_path): | |||||
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " | |||||
f" {self.timestamp_path.absolute()}.") | |||||
else: | |||||
logger.info(f"Resume to save in path: {timestamp_path}.") | |||||
self.timestamp_path = Path(timestamp_path) | |||||
_topk_model = states['_topk_model'] | |||||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||||
if save_topk is not None and self.save_topk is not None: | |||||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||||
f"as {save_topk}." | |||||
self._topk_model.update(self._topk_model) | |||||
self._real_monitor = states["real_monitor"] | self._real_monitor = states["real_monitor"] | ||||
def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"): | |||||
if self.save_every_n_epochs is not None: | |||||
if self.is_trainer_checkpoint: | |||||
_fn_every_n_epochs = trainer.save | |||||
else: | |||||
_fn_every_n_epochs = trainer.save_model | |||||
_fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None) | |||||
_fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs) | |||||
_fn_every_n_epochs() | |||||
def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"): | |||||
if self.save_every_n_global_batches is not None: | |||||
if self.is_trainer_checkpoint: | |||||
_fn_every_n_global_batches = trainer.save | |||||
else: | |||||
_fn_every_n_global_batches = trainer.save_model | |||||
_fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None) | |||||
_fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches) | |||||
_fn_every_n_global_batches() | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | ||||
""" | |||||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | |||||
:param trainer: | |||||
:param validate_res: | |||||
:return: | |||||
""" | |||||
if self.save_topk is not None: | if self.save_topk is not None: | ||||
_metric_value = self._get_validate_metric(validate_res) | _metric_value = self._get_validate_metric(validate_res) | ||||
_saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value) | |||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||||
f"-{self._real_monitor}_{_metric_value}" | |||||
_should_save = False | _should_save = False | ||||
if self._topn < self.save_topk: | if self._topn < self.save_topk: | ||||
self._topk_model[_saved_name] = _metric_value | |||||
self._topk_model[folder_name] = _metric_value | |||||
self._topn += 1 | self._topn += 1 | ||||
_should_save = True | _should_save = True | ||||
else: | else: | ||||
@@ -192,39 +199,27 @@ class CheckpointCallback(Callback): | |||||
key=lambda x: self._topk_model[x]) | key=lambda x: self._topk_model[x]) | ||||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | ||||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | ||||
self._topk_model[_saved_name] = _metric_value | |||||
self._topk_model[folder_name] = _metric_value | |||||
_should_save = True | _should_save = True | ||||
self._topk_model.pop(_least_valuable_model) | self._topk_model.pop(_least_valuable_model) | ||||
synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model)) | |||||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | |||||
assert len(self._topk_model) == self.save_topk == self._topn | assert len(self._topk_model) == self.save_topk == self._topn | ||||
if _should_save: | if _should_save: | ||||
self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name) | |||||
self.save(trainer, folder_name=folder_name) | |||||
def _save_last(self, trainer: "fastNLP.Trainer"): | |||||
if self.save_last: | |||||
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last") | |||||
def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None, | |||||
substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None): | |||||
# 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者 | |||||
# epoch-batch-monitor-monitor_value; | |||||
if substitute_folder is None: | |||||
folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric)) | |||||
else: | |||||
folder = self.log_filepath.joinpath(substitute_folder) | |||||
def save(self, trainer, folder_name): | |||||
""" | |||||
执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 | |||||
:param trainer: | |||||
:param folder_name: | |||||
:return: | |||||
""" | |||||
folder = self.timestamp_path.joinpath(folder_name) | |||||
synchronize_mkdir(folder) | synchronize_mkdir(folder) | ||||
# 然后再调用 trainer 的 save_model(用于保存模型)或者 save(用于断点重训)函数; | |||||
if substitute_fn is not None: | |||||
_fn = substitute_fn | |||||
else: | |||||
if self.is_trainer_checkpoint: | |||||
_fn = trainer.save | |||||
else: | |||||
_fn = trainer.save_model | |||||
_fn = getattr(trainer, self.save_fn_name) | |||||
_fn( | _fn( | ||||
folder=folder, | folder=folder, | ||||
only_state_dict=self.only_state_dict, | only_state_dict=self.only_state_dict, | ||||
@@ -243,18 +238,95 @@ class CheckpointCallback(Callback): | |||||
self._real_monitor = use_monitor | self._real_monitor = use_monitor | ||||
return value | return value | ||||
def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False, | |||||
metric: Optional[Union[int, float]] = None) -> str: | |||||
@property | |||||
def folder_prefix(self): | |||||
raise NotImplementedError("The `folder_prefix` is not specified") | |||||
@property | |||||
def save_fn_name(self): | |||||
raise NotImplementedError("The `save_fn_name` is not specified.") | |||||
class ModelCheckpointCallback(CheckpointCallback): | |||||
""" | |||||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
- save_folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||||
- model-last/ # 最后一个 epoch 的保存 | |||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: | |||||
""" | |||||
@property | |||||
def save_fn_name(self): | |||||
return 'save_model' | |||||
@property | |||||
def callback_name(self): | |||||
""" | """ | ||||
获取当前保存模型的真正地名字; | |||||
metric 参数仅当 mode 为 'topk' 时起作用; | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||||
:return: | |||||
""" | """ | ||||
cur_epoch_idx = trainer.cur_epoch_idx | |||||
global_forward_batches = trainer.global_forward_batches | |||||
_other = "" | |||||
if topk: | |||||
_other = f"_{metric}" | |||||
return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}" | |||||
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | |||||
def folder_prefix(self): | |||||
return 'model' | |||||
class TrainerCheckpointCallback(CheckpointCallback): | |||||
""" | |||||
保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
- save_folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||||
- trainer-last/ # 最后一个 epoch 的保存 | |||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: | |||||
""" | |||||
@property | |||||
def save_fn_name(self): | |||||
return 'save' | |||||
@property | @property | ||||
def callback_name(self): | def callback_name(self): | ||||
@@ -262,6 +334,8 @@ class CheckpointCallback(Callback): | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | ||||
:return: | :return: | ||||
""" | """ | ||||
return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}" | |||||
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | |||||
def folder_prefix(self): | |||||
return 'trainer' |
@@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback): | |||||
请在函数内完成对模型的保存。 | 请在函数内完成对模型的保存。 | ||||
:param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | :param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | ||||
请在函数内完成对模型的加载。 | 请在函数内完成对模型的加载。 | ||||
:param delete_after_train: 在加载了最佳模型之后是否删掉模型。 | |||||
:param delete_after_train: 在训练结束后是否删掉模型。 | |||||
""" | """ | ||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object." | assert callable(model_load_fn), "`model_load_fn` must be a callable object." | ||||
@@ -133,17 +133,18 @@ class Evaluator: | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | |||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||||
""" | """ | ||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | ||||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name | |||||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name | |||||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name | |||||
其中 metric_indicator_name 可能不存在。 | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||||
:return: | :return: | ||||
""" | """ | ||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | ||||
@@ -157,7 +158,6 @@ class Evaluator: | |||||
assert self.driver.has_test_dataloaders() | assert self.driver.has_test_dataloaders() | ||||
metric_results = {} | metric_results = {} | ||||
self.reset() | self.reset() | ||||
evaluate_context = self.driver.get_evaluate_context() | evaluate_context = self.driver.get_evaluate_context() | ||||
self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | ||||
@@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.set_deterministic_dataloader(self.dataloader) | self.driver.set_deterministic_dataloader(self.dataloader) | ||||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | ||||
reproducible=self.callback_manager.has_trainer_chechpoint) | |||||
reproducible=self.callback_manager.has_trainer_checkpoint) | |||||
self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | ||||
self.on_after_trainer_initialized(self.driver) | self.on_after_trainer_initialized(self.driver) | ||||
@@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): | |||||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | ||||
if self.evaluator is not None and num_eval_sanity_batch > 0: | if self.evaluator is not None and num_eval_sanity_batch > 0: | ||||
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") | |||||
self.on_sanity_check_begin() | self.on_sanity_check_begin() | ||||
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | ||||
self.on_sanity_check_end(sanity_check_res) | self.on_sanity_check_end(sanity_check_res) | ||||
@@ -509,7 +510,7 @@ class Trainer(TrainerEventTrigger): | |||||
:param folder: 保存模型的地址; | :param folder: 保存模型的地址; | ||||
:param only_state_dict: 是否只保存模型的 `state_dict`; | :param only_state_dict: 是否只保存模型的 `state_dict`; | ||||
:param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | |||||
:param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | |||||
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | :param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | ||||
""" | """ | ||||
@@ -534,7 +535,16 @@ class Trainer(TrainerEventTrigger): | |||||
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False, | def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False, | ||||
model_load_fn: Optional[Callable] = None, **kwargs): | model_load_fn: Optional[Callable] = None, **kwargs): | ||||
""" | |||||
加载模型 | |||||
:param folder: 读取 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, | |||||
直接将该 folder 传递到 model_load_fn 中。 | |||||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 model_load_fn 不为 None 时,该参数无意义。 | |||||
:param model_load_fn: callable 的函数,接受一个 folder 作为参数,不返回任何内容。 | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
self.on_load_model() | self.on_load_model() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | if not isinstance(folder, (io.BytesIO, BinaryIO)): | ||||
@@ -555,7 +565,13 @@ class Trainer(TrainerEventTrigger): | |||||
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | ||||
r""" | r""" | ||||
用于断点重训的保存函数; | |||||
用于断点重训 Trainer 的保存函数; | |||||
:param folder: | |||||
:param only_state_dict: | |||||
:param model_save_fn: | |||||
:param kwargs: | |||||
:return: | |||||
""" | """ | ||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -8,9 +8,8 @@ __all__ = [ | |||||
import _pickle as pickle | import _pickle as pickle | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from typing import Optional, List, Callable, Union, Dict, Any | |||||
from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||||
from functools import partial | from functools import partial | ||||
import warnings | |||||
import numpy as np | import numpy as np | ||||
from threading import Thread | from threading import Thread | ||||
@@ -197,6 +196,20 @@ class DataSet: | |||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
def __setitem__(self, key, value): | |||||
assert isinstance(key, int) and key<len(self) | |||||
assert isinstance(value, Instance) or isinstance(value, Mapping) | |||||
ins_keys = set(value.keys()) | |||||
ds_keys = set(self.get_field_names()) | |||||
if len(ins_keys - ds_keys) != 0: | |||||
raise KeyError(f"The following keys are not found in the Dataset:{list(ins_keys - ds_keys)}.") | |||||
if len(ds_keys - ins_keys) != 0: | |||||
raise KeyError(f"The following keys are not found in the Instance:{list(ds_keys - ins_keys)}.") | |||||
for field_name, field in self.field_arrays.items(): | |||||
field[key] = value[field_name] | |||||
def __getattribute__(self, item): | def __getattribute__(self, item): | ||||
return object.__getattribute__(self, item) | return object.__getattribute__(self, item) | ||||
@@ -813,6 +826,3 @@ class DataSet: | |||||
self.collate_fns.set_input(*field_names) | self.collate_fns.set_input(*field_names) | ||||
class IterableDataset: | |||||
pass | |||||
@@ -46,9 +46,6 @@ class FieldArray: | |||||
def __setitem__(self, idx: int, val: Any): | def __setitem__(self, idx: int, val: Any): | ||||
assert isinstance(idx, int) | assert isinstance(idx, int) | ||||
if idx == -1: | |||||
idx = len(self) - 1 | |||||
assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}" | |||||
self.content[idx] = val | self.content[idx] = val | ||||
def get(self, indices: Union[int, List[int]]): | def get(self, indices: Union[int, List[int]]): | ||||
@@ -79,7 +76,7 @@ class FieldArray: | |||||
def split(self, sep: str = None, inplace: bool = True): | def split(self, sep: str = None, inplace: bool = True): | ||||
r""" | r""" | ||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | |||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。 | |||||
:param sep: 分割符,如果为None则直接调用str.split()。 | :param sep: 分割符,如果为None则直接调用str.split()。 | ||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | ||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod | |||||
from datetime import datetime | from datetime import datetime | ||||
from pathlib import Path | from pathlib import Path | ||||
from io import BytesIO | from io import BytesIO | ||||
import json | |||||
__all__ = [ | __all__ = [ | ||||
'Driver' | 'Driver' | ||||
@@ -68,9 +69,12 @@ class Driver(ABC): | |||||
def set_sampler_epoch(self, dataloader, cur_epoch_idx): | def set_sampler_epoch(self, dataloader, cur_epoch_idx): | ||||
r""" | r""" | ||||
对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | 对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | ||||
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 | |||||
:param dataloader: 需要设置 epoch 的 dataloader 。 | |||||
:param cur_epoch_idx: 当前是第几个 epoch; | :param cur_epoch_idx: 当前是第几个 epoch; | ||||
""" | """ | ||||
@abstractmethod | @abstractmethod | ||||
def train_step(self, batch): | def train_step(self, batch): | ||||
""" | """ | ||||
@@ -444,13 +448,14 @@ class Driver(ABC): | |||||
exc_type, exc_value, exc_traceback_obj = sys.exc_info() | exc_type, exc_value, exc_traceback_obj = sys.exc_info() | ||||
_write_exc_info = { | _write_exc_info = { | ||||
'exc_type': exc_type, | |||||
'exc_value': exc_value, | |||||
'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||||
'global_rank': getattr(self, "global_rank", None), | |||||
'rank': self.get_local_rank(), | |||||
'exc_type': str(exc_type.__name__), | |||||
'exc_value': str(exc_value), | |||||
'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||||
'exc_global_rank': getattr(self, "global_rank", None), | |||||
'exc_local_rank': self.get_local_rank(), | |||||
} | } | ||||
sys.stderr.write(str(_write_exc_info)+"\n") | |||||
sys.stderr.write("\nException info:\n") | |||||
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n") | |||||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | ||||
for pid in self._pids: | for pid in self._pids: | ||||
@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
# world_size 和 rank | # world_size 和 rank | ||||
if FASTNLP_BACKEND_LAUNCH in os.environ: | if FASTNLP_BACKEND_LAUNCH in os.environ: | ||||
if device is not None: | if device is not None: | ||||
logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
"up your script. And we will directly get the local device via " | "up your script. And we will directly get the local device via " | ||||
"`os.environ['LOCAL_RANK']`.") | "`os.environ['LOCAL_RANK']`.") | ||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | ||||
@@ -65,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
if not isinstance(device, List): | if not isinstance(device, List): | ||||
return TorchSingleDriver(model, device, **kwargs) | return TorchSingleDriver(model, device, **kwargs) | ||||
else: | else: | ||||
logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||||
logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||||
"`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | ||||
"`driver` as `TorchDDPDriver`.") | "`driver` as `TorchDDPDriver`.") | ||||
return TorchDDPDriver(model, device, **kwargs) | return TorchDDPDriver(model, device, **kwargs) | ||||
@@ -143,8 +143,6 @@ class TorchDriver(Driver): | |||||
:param filepath: 保存到哪个文件夹; | :param filepath: 保存到哪个文件夹; | ||||
:param only_state_dict: 是否只保存权重; | :param only_state_dict: 是否只保存权重; | ||||
:param model_save_fn: | |||||
:return: | :return: | ||||
""" | """ | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
@@ -44,6 +44,9 @@ __all__ = [ | |||||
] | ] | ||||
def get_fn_arg_names(fn: Callable) -> List[str]: | def get_fn_arg_names(fn: Callable) -> List[str]: | ||||
r""" | r""" | ||||
返回一个函数的所有参数的名字; | 返回一个函数的所有参数的名字; | ||||
@@ -65,7 +65,7 @@ def set_env_on_import(): | |||||
# fastNLP 内部使用的一些变量 | # fastNLP 内部使用的一些变量 | ||||
if FASTNLP_LAUNCH_TIME not in os.environ: | if FASTNLP_LAUNCH_TIME not in os.environ: | ||||
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}" | |||||
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}" | |||||
os.environ[FASTNLP_LAUNCH_TIME] = cur_time | os.environ[FASTNLP_LAUNCH_TIME] = cur_time | ||||
# 设置对应的值 | # 设置对应的值 | ||||
@@ -8,7 +8,7 @@ import torch.distributed as dist | |||||
from pathlib import Path | from pathlib import Path | ||||
import re | import re | ||||
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | |||||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
@@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1( | |||||
version, | version, | ||||
only_state_dict | only_state_dict | ||||
): | ): | ||||
# def test_model_checkpoint_callback_1( | |||||
# model_and_optimizers: TrainerParameters, | |||||
# driver='torch_ddp', | |||||
# device=[0, 1], | |||||
# version=1, | |||||
# only_state_dict=True | |||||
# ): | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | path = Path.cwd().joinpath(f"test_model_checkpoint") | ||||
path.mkdir(exist_ok=True, parents=True) | path.mkdir(exist_ok=True, parents=True) | ||||
if version == 0: | if version == 0: | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
ModelCheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=1, | save_every_n_epochs=1, | ||||
save_every_n_global_batches=123, # 避免和 epoch 的保存重复; | |||||
save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||||
save_topk=None, | save_topk=None, | ||||
save_last=False, | save_last=False, | ||||
save_on_exception=None, | save_on_exception=None, | ||||
@@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1( | |||||
] | ] | ||||
elif version == 1: | elif version == 1: | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
ModelCheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=3, | save_every_n_epochs=3, | ||||
save_every_n_global_batches=None, | |||||
save_every_n_batches=None, | |||||
save_topk=2, | save_topk=2, | ||||
save_last=True, | save_last=True, | ||||
save_on_exception=None, | save_on_exception=None, | ||||
@@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1( | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
n_epochs=10, | n_epochs=10, | ||||
callbacks=callbacks, | callbacks=callbacks, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
@@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1( | |||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | if driver == "torch": | ||||
assert "epoch_10-global_batch_250-acc" in all_saved_model_paths | |||||
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths | |||||
assert "model-epoch_10" in all_saved_model_paths | |||||
assert "model-epoch_4-batch_123" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] | |||||
epoch_save_path = all_saved_model_paths["model-epoch_10"] | |||||
step_save_path = all_saved_model_paths["model-epoch_4-batch_123"] | |||||
assert len(all_saved_model_paths) == 12 | assert len(all_saved_model_paths) == 12 | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "epoch_6-global_batch_78-acc" in all_saved_model_paths | |||||
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths | |||||
assert "model-epoch_6" in all_saved_model_paths | |||||
assert "model-epoch_9-batch_123" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] | |||||
epoch_save_path = all_saved_model_paths["model-epoch_6"] | |||||
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] | |||||
assert len(all_saved_model_paths) == 11 | assert len(all_saved_model_paths) == 11 | ||||
all_state_dicts = [epoch_save_path, step_save_path] | all_state_dicts = [epoch_save_path, step_save_path] | ||||
elif version == 1: | elif version == 1: | ||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||||
pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | |||||
if driver == "torch": | if driver == "torch": | ||||
assert "epoch_9-global_batch_225-acc" in all_saved_model_paths | |||||
assert "last" in all_saved_model_paths | |||||
assert "model-epoch_9" in all_saved_model_paths | |||||
assert "model-last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
each_folder_name = pattern.findall(each_folder_name) | each_folder_name = pattern.findall(each_folder_name) | ||||
@@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1( | |||||
aLL_topk_folders.append(each_folder_name[0]) | aLL_topk_folders.append(each_folder_name[0]) | ||||
assert len(aLL_topk_folders) == 2 | assert len(aLL_topk_folders) == 2 | ||||
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"] | |||||
last_save_path = all_saved_model_paths["last"] | |||||
epoch_save_path = all_saved_model_paths["model-epoch_9"] | |||||
last_save_path = all_saved_model_paths["model-last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | ||||
assert len(all_saved_model_paths) == 6 | assert len(all_saved_model_paths) == 6 | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "epoch_9-global_batch_117-acc" in all_saved_model_paths | |||||
assert "last" in all_saved_model_paths | |||||
assert "model-epoch_9" in all_saved_model_paths | |||||
assert "model-last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1( | |||||
aLL_topk_folders.append(each_folder_name[0]) | aLL_topk_folders.append(each_folder_name[0]) | ||||
assert len(aLL_topk_folders) == 2 | assert len(aLL_topk_folders) == 2 | ||||
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"] | |||||
last_save_path = all_saved_model_paths["last"] | |||||
epoch_save_path = all_saved_model_paths["model-epoch_9"] | |||||
last_save_path = all_saved_model_paths["model-last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | ||||
assert len(all_saved_model_paths) == 6 | assert len(all_saved_model_paths) == 6 | ||||
@@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1( | |||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
# pass | |||||
pass | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2( | |||||
raise NotImplementedError | raise NotImplementedError | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
ModelCheckpointCallback( | |||||
monitor="acc1", | monitor="acc1", | ||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=None, | save_every_n_epochs=None, | ||||
save_every_n_global_batches=None, | |||||
save_every_n_batches=None, | |||||
save_topk=None, | save_topk=None, | ||||
save_last=False, | save_last=False, | ||||
save_on_exception=NotImplementedError, | save_on_exception=NotImplementedError, | ||||
@@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2( | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | if driver == "torch": | ||||
assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths | |||||
exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"] | |||||
assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths | |||||
exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths | |||||
exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"] | |||||
assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths | |||||
exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"] | |||||
assert len(all_saved_model_paths) == 1 | assert len(all_saved_model_paths) == 1 | ||||
all_state_dicts = [exception_model_path] | all_state_dicts = [exception_model_path] | ||||
@@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1( | |||||
if version == 0: | if version == 0: | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
TrainerCheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=7, | save_every_n_epochs=7, | ||||
save_every_n_global_batches=123, # 避免和 epoch 的保存重复; | |||||
save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||||
save_topk=None, | save_topk=None, | ||||
save_last=False, | save_last=False, | ||||
save_on_exception=None, | save_on_exception=None, | ||||
@@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1( | |||||
] | ] | ||||
elif version == 1: | elif version == 1: | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
TrainerCheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=None, | save_every_n_epochs=None, | ||||
save_every_n_global_batches=None, | |||||
save_every_n_batches=None, | |||||
save_topk=2, | save_topk=2, | ||||
save_last=True, | save_last=True, | ||||
save_on_exception=None, | save_on_exception=None, | ||||
@@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1( | |||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | if driver == "torch": | ||||
assert "epoch_7-global_batch_175-acc" in all_saved_model_paths | |||||
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths | |||||
assert "trainer-epoch_7" in all_saved_model_paths | |||||
assert "trainer-epoch_4-batch_123" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] | |||||
epoch_save_path = all_saved_model_paths["trainer-epoch_7"] | |||||
step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"] | |||||
assert len(all_saved_model_paths) == 3 | assert len(all_saved_model_paths) == 3 | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "epoch_7-global_batch_91-acc" in all_saved_model_paths | |||||
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths | |||||
assert "trainer-epoch_7" in all_saved_model_paths | |||||
assert "trainer-epoch_9-batch_123" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] | |||||
epoch_save_path = all_saved_model_paths["trainer-epoch_7"] | |||||
step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"] | |||||
assert len(all_saved_model_paths) == 2 | assert len(all_saved_model_paths) == 2 | ||||
all_state_dicts = [epoch_save_path, step_save_path] | all_state_dicts = [epoch_save_path, step_save_path] | ||||
elif version == 1: | elif version == 1: | ||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | |||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | if driver == "torch": | ||||
assert "last" in all_saved_model_paths | |||||
assert "trainer-last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
each_folder_name = pattern.findall(each_folder_name) | each_folder_name = pattern.findall(each_folder_name) | ||||
@@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1( | |||||
aLL_topk_folders.append(each_folder_name[0]) | aLL_topk_folders.append(each_folder_name[0]) | ||||
assert len(aLL_topk_folders) == 2 | assert len(aLL_topk_folders) == 2 | ||||
last_save_path = all_saved_model_paths["last"] | |||||
last_save_path = all_saved_model_paths["trainer-last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | ||||
assert len(all_saved_model_paths) == 3 | assert len(all_saved_model_paths) == 3 | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "last" in all_saved_model_paths | |||||
assert "trainer-last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1( | |||||
aLL_topk_folders.append(each_folder_name[0]) | aLL_topk_folders.append(each_folder_name[0]) | ||||
assert len(aLL_topk_folders) == 2 | assert len(aLL_topk_folders) == 2 | ||||
last_save_path = all_saved_model_paths["last"] | |||||
last_save_path = all_saved_model_paths["trainer-last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | ||||
assert len(all_saved_model_paths) == 3 | assert len(all_saved_model_paths) == 3 | ||||
@@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2( | |||||
device, | device, | ||||
version | version | ||||
): | ): | ||||
pytest.skip("Skip transformers test for now.") | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | path = Path.cwd().joinpath(f"test_model_checkpoint") | ||||
path.mkdir(exist_ok=True, parents=True) | path.mkdir(exist_ok=True, parents=True) | ||||
import transformers | |||||
import transformers # 版本4.16.2 | |||||
import torch | import torch | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from transformers import AutoModelForSequenceClassification | from transformers import AutoModelForSequenceClassification | ||||
@@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2( | |||||
if version == 0: | if version == 0: | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
TrainerCheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=None, | save_every_n_epochs=None, | ||||
save_every_n_global_batches=50, | |||||
save_every_n_batches=50, | |||||
save_topk=None, | save_topk=None, | ||||
save_last=False, | save_last=False, | ||||
save_on_exception=None, | save_on_exception=None, | ||||
@@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2( | |||||
] | ] | ||||
elif version == 1: | elif version == 1: | ||||
callbacks = [ | callbacks = [ | ||||
CheckpointCallback( | |||||
TrainerCheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | save_folder=path, | ||||
save_every_n_epochs=None, | save_every_n_epochs=None, | ||||
save_every_n_global_batches=None, | |||||
save_every_n_batches=None, | |||||
save_topk=1, | save_topk=1, | ||||
save_last=True, | save_last=True, | ||||
save_on_exception=None, | save_on_exception=None, | ||||
@@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2( | |||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | if driver == "torch": | ||||
assert "epoch_1-global_batch_200-acc" in all_saved_model_paths | |||||
assert "trainer-epoch_1-batch_200" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"] | |||||
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] | |||||
assert len(all_saved_model_paths) == 4 | assert len(all_saved_model_paths) == 4 | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "epoch_1-global_batch_100-acc" in all_saved_model_paths | |||||
assert "trainer-epoch_1-batch_100" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"] | |||||
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"] | |||||
assert len(all_saved_model_paths) == 2 | assert len(all_saved_model_paths) == 2 | ||||
all_state_dicts = [epoch_save_path] | all_state_dicts = [epoch_save_path] | ||||
elif version == 1: | elif version == 1: | ||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | |||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | if driver == "torch": | ||||
assert "last" in all_saved_model_paths | |||||
assert "trainer-last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
each_folder_name = pattern.findall(each_folder_name) | each_folder_name = pattern.findall(each_folder_name) | ||||
@@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2( | |||||
aLL_topk_folders.append(each_folder_name[0]) | aLL_topk_folders.append(each_folder_name[0]) | ||||
assert len(aLL_topk_folders) == 1 | assert len(aLL_topk_folders) == 1 | ||||
last_save_path = all_saved_model_paths["last"] | |||||
last_save_path = all_saved_model_paths["trainer-last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | ||||
assert len(all_saved_model_paths) == 2 | assert len(all_saved_model_paths) == 2 | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
else: | else: | ||||
assert "last" in all_saved_model_paths | |||||
assert "trainer-last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2( | |||||
aLL_topk_folders.append(each_folder_name[0]) | aLL_topk_folders.append(each_folder_name[0]) | ||||
assert len(aLL_topk_folders) == 1 | assert len(aLL_topk_folders) == 1 | ||||
last_save_path = all_saved_model_paths["last"] | |||||
last_save_path = all_saved_model_paths["trainer-last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | ||||
assert len(all_saved_model_paths) == 2 | assert len(all_saved_model_paths) == 2 | ||||
@@ -105,6 +105,20 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(isinstance(field_array, FieldArray)) | self.assertTrue(isinstance(field_array, FieldArray)) | ||||
self.assertEqual(len(field_array), 40) | self.assertEqual(len(field_array), 40) | ||||
def test_setitem(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.add_field('i', list(range(len(ds)))) | |||||
assert ds.get_field('i').content == list(range(len(ds))) | |||||
import random | |||||
random.shuffle(ds) | |||||
import numpy as np | |||||
np.random.shuffle(ds) | |||||
assert ds.get_field('i').content != list(range(len(ds))) | |||||
ins1 = ds[1] | |||||
ds[2] = ds[1] | |||||
assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y'] | |||||
def test_get_item_error(self): | def test_get_item_error(self): | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||