diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 902421c8..b37eda63 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -390,4 +390,23 @@ class HasMonitorCallback(Callback): if (self.larger_better and monitor_value1 > monitor_value2) or \ (not self.larger_better and monitor_value1 < monitor_value2): better = True - return better \ No newline at end of file + return better + + @property + def monitor_name(self): + """ + 返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 + + :return: + """ + if callable(self.monitor): + try: + monitor_name = self.monitor.__qualname__ + except: + monitor_name = self.monitor.__name__ + elif self.monitor is None: + return None + else: + # 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 + monitor_name = str(self.monitor) + return monitor_name diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 7bbdb2fe..d2d97294 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir class CheckpointCallback(HasMonitorCallback): def __init__( self, - monitor, + monitor:Optional[Union[str, Callable]]=None, save_folder: Optional[Union[str, Path]] = None, save_every_n_epochs: Optional[int] = None, save_every_n_batches: Optional[int] = None, - save_last: bool = True, + save_last: bool = False, save_topk: Optional[int] = None, save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, larger_better: bool = True, @@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback): model_save_fn: Optional[Callable] = None, **kwargs, ): + """ + 请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 + + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 
。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param save_every_n_epochs: 多少个 epoch 保存一次。 + :param save_every_n_batches: 多少个 batch 保存一次。 + :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param save_topk: 保存 monitor 结果 topK 个。 + :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 + :param larger_better: monitor 的值是否时越大越好。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param kwargs: + """ super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=save_topk is not None) if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") save_folder = Path.cwd() + save_folder = Path(save_folder) if not save_folder.exists(): raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") elif save_folder.is_file(): @@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback): else: save_on_exception = [] - self.save_folder = Path(save_folder) + self.save_folder = save_folder self.save_every_n_epochs = save_every_n_epochs self.save_every_n_batches = save_every_n_batches self.save_last = save_last @@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback): # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) - # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; - synchronize_mkdir(self.timestamp_path) + # 该 folder 只在保存真的要发生的时候再创建。 def on_after_trainer_initialized(self, trainer, driver): if self.save_topk is not None: @@ -98,8 +117,6 @@ class 
CheckpointCallback(HasMonitorCallback): logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") def on_validate_end(self, trainer, results): - if len(results) == 0: - return self._save_topk(trainer, results) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): @@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback): states['timestamp_path'] = str(self.timestamp_path.absolute()) states['_topk_model'] = deepcopy(self._topk_model) states['save_topk'] = 0 if self.save_topk is None else self.save_topk - states['_real_monitor'] = self._real_monitor + if isinstance(self._real_monitor, str): + states['_real_monitor'] = self._real_monitor return states def on_load_checkpoint(self, trainer, states: Optional[Dict]): timestamp_path = states['timestamp_path'] if not os.path.exists(timestamp_path): - logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " + logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " f" {self.timestamp_path.absolute()}.") else: - logger.info(f"Resume to save in path: {timestamp_path}.") + logger.info(f"Resume to checkpoint in path: {timestamp_path}.") self.timestamp_path = Path(timestamp_path) _topk_model = states['_topk_model'] save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) @@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback): assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ f"as {save_topk}." 
self._topk_model.update(self._topk_model) - self._real_monitor = states["real_monitor"] + + self._real_monitor = states["_real_monitor"] def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): """ @@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback): model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 - :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), - 返回一个 float 值作为 monitor 的结果。 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback): """ @property def save_fn_name(self): + """ + 调用 Trainer 中的哪个函数。 + + :return: + """ return 'save_model' @property @@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback): 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; :return: """ - return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" @property def folder_prefix(self): @@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback): model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 - :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), - 返回一个 float 值作为 monitor 的结果。 
+ :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback): """ @property def save_fn_name(self): + """ + 调用 Trainer 中的哪个函数。 + + :return: + """ return 'save' @property @@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback): 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; :return: """ - return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + + return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" @property def folder_prefix(self): diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py index b1842d43..c679ad7e 100644 --- a/fastNLP/core/callbacks/early_stop_callback.py +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback): def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): """ - :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 - evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 :param patience: 多少次 validate 不没有提升就停止。 """ @@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback): states = { 'patience': self.patience, 'wait': self.wait, - 'monitor': self.monitor, 'monitor_value': self.monitor_value } + if not callable(self._real_monitor): 
+ states['_real_monitor'] = self._real_monitor return states def on_load_checkpoint(self, trainer, states): self.patience = states['patience'] self.wait = states['wait'] - self.monitor = states['monitor'] self.monitor_value = float(states['monitor_value']) + if '_real_monitor' in states: + self._real_monitor = states['_real_monitor'] + @property def callback_name(self): - return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' + return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}' diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index e068326b..09f85d01 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 - evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 67176387..f351f204 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -44,10 +44,11 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 - 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 - :param larger_better: 
是否是monitor的结果越大越好。 - :param format_json: 是否format json再打印 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 + 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor + 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param larger_better: 是否是 monitor 的结果越大越好。 + :param format_json: 是否格式化 json 再打印 """ super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every @@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( - 字典类型),返回一个 float 值作为 monitor 的结果。 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 + 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor + 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 3013c316..d447a0f2 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -36,7 +36,7 @@ class Evaluator: model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, - driver: Union[str, Driver] = 'single', + driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable @@ -49,8 +49,8 @@ class Evaluator: ): """ - :param dataloaders: :param model: + :param dataloaders: :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 metric ,torchmetrics,allennlpmetrics等。 :param driver: 使用 driver 。 @@ -120,7 +120,8 @@ class Evaluator: if evaluate_fn is not None and not isinstance(evaluate_fn, str): 
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") - self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) + self._evaluate_step, self._evaluate_step_signature_fn = \ + self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) self.evaluate_fn = evaluate_fn self.dataloaders = {} @@ -134,8 +135,6 @@ class Evaluator: self.driver.barrier() - self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) - def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: """ 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index a3219e6d..7dbe9775 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -20,7 +20,7 @@ class TrainBatchLoop(Loop): else lambda *args, **kwargs: None dataloader = iter(dataloader) indices = None - while True: + while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: try: trainer.on_fetch_data_begin() batch = next(dataloader) @@ -30,10 +30,8 @@ class TrainBatchLoop(Loop): batch = trainer.move_data_to_device(batch) except StopIteration: break - except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception - break except BaseException as e: - if indices: + if indices and not isinstance(e, EarlyStopException): logger.debug(f"The following exception happens when running on samples: {indices}") raise e diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 3859fca0..2d5fcfd4 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -264,7 +264,6 @@ class Trainer(TrainerEventTrigger): self.on_after_trainer_initialized(self.driver) self.driver.barrier() - self.driver.check_dataloader_legality(self.train_dataloader, 
"train_dataloader", is_train=True) def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, @@ -310,7 +309,7 @@ class Trainer(TrainerEventTrigger): self.num_batches_per_epoch = len(self.dataloader) self.total_batches = self.num_batches_per_epoch * self.n_epochs - + self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch self.on_train_begin() self.driver.barrier() self.driver.zero_grad(self.set_grad_to_none) @@ -637,6 +636,8 @@ class Trainer(TrainerEventTrigger): :param folder: 保存断点重训 states 的文件地址; :param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; + :param only_state_dict: 保存的 model 是否只包含了权重。 + :param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。 """ self.driver.barrier() if isinstance(folder, str): @@ -675,8 +676,6 @@ class Trainer(TrainerEventTrigger): # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') - self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ - self.batch_idx_in_epoch # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py index 2327c1e5..496533d2 100644 --- a/fastNLP/core/controllers/utils/state.py +++ b/fastNLP/core/controllers/utils/state.py @@ -65,10 +65,10 @@ class TrainerState: """ n_epochs: Optional[int] = None # 无论如何重新算 - cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; - global_forward_batches: Optional[int] = None # 断点重训 + cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0; + 
global_forward_batches: Optional[int] = 0 # 断点重训 - batch_idx_in_epoch: Optional[int] = None # 断点重训 + batch_idx_in_epoch: Optional[int] = 0 # 断点重训 num_batches_per_epoch: Optional[int] = None # 无论如何重新算 diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index b1015b47..0ef7f053 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -86,7 +86,7 @@ class Driver(ABC): 函数; :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; - :param fn: 由 Trainer 传入的用于网络前向传播一次的函数; + :param fn: 调用该函数进行一次计算。 :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); @@ -126,17 +126,6 @@ class Driver(ABC): def model(self, model): self._model = model - @staticmethod - def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): - r""" - 该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 - 行为是不相同的; - - :param dataloader: 需要检测的输入的 `dataloader`; - :param dataloader_name: - """ - raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") - @property def optimizers(self) -> List: r""" diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index de0af7f2..977eaf2c 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -34,10 +34,10 @@ if _NEED_IMPORT_PADDLE: from paddle.optimizer import Optimizer _reduces = { - 'max': paddle.max, - 'min': paddle.min, - 'mean': paddle.mean, - 'sum': paddle.sum + "max": paddle.max, + "min": paddle.min, + "mean": paddle.mean, + "sum": paddle.sum } class PaddleDriver(Driver): @@ -254,24 +254,24 @@ class PaddleDriver(Driver): else: raise RuntimeError("This condition is not supposed 
to appear. Please report a bug to us.") - num_consumed_batches = states.pop('num_consumed_batches') - if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): + num_consumed_batches = states.pop("num_consumed_batches") + if hasattr(sampler, "state_dict") and callable(sampler.state_dict): sampler_states = sampler.state_dict() # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + # 会造成多余实际消耗的问题。 + num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 - try: - num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 - num_consumed_batches = sampler_states['num_consumed_samples'] - sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] - assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." - + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + try: + sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size + except: # 有可能 batch_size 为 None,就只有损失精度了 + pass + assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." else: raise RuntimeError( - 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') + "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") + states["sampler_states"] = sampler_states # 2. 
保存模型的状态; if should_save_model: @@ -326,7 +326,7 @@ class PaddleDriver(Driver): batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last ) - sampler.load_state_dict(states['sampler_states']) + sampler.load_state_dict(states["sampler_states"]) states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) # 4. 修改 trainer_state.batch_idx_in_epoch @@ -355,7 +355,7 @@ class PaddleDriver(Driver): return paddle.no_grad @staticmethod - def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']): + def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]): r""" 用来将模型转移到指定的 device 上; 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 @@ -363,7 +363,7 @@ class PaddleDriver(Driver): if device is not None: model.to(device) - def move_data_to_device(self, batch: 'paddle.Tensor'): + def move_data_to_device(self, batch: "paddle.Tensor"): r""" 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 @@ -404,7 +404,7 @@ class PaddleDriver(Driver): if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) - def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx): + def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): r""" 对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 55af3367..a37525f4 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver): if hasattr(model, fn): fn = getattr(model, fn) if not callable(fn): - raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") return fn, None 
elif fn in {"train_step", "evaluate_step"}: return model, model.forward diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index b7aebec8..233d7040 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -199,6 +199,7 @@ class TorchDriver(Driver): num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + states['sampler_states'] = sampler_states else: raise RuntimeError( 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index 0f8657b2..69b16427 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -1,19 +1,20 @@ import pytest import os +os.environ["FASTNLP_BACKEND"] = "paddle" from typing import Any from dataclasses import dataclass -from paddle.optimizer import Adam -from paddle.io import DataLoader - from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.callbacks.progress_callback import RichCallback from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK +from paddle.optimizer import Adam +from paddle.io import DataLoader + -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification -from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context @@ -48,64 +49,31 @@ class 
TrainerParameters:
     output_mapping: Any = None
     metrics: Any = None
 
-# @pytest.fixture(params=[0], autouse=True)
-# def model_and_optimizers(request):
-#     """
-#     初始化单卡模式的模型和优化器
-#     """
-#     trainer_params = TrainerParameters()
-#     print(paddle.device.get_device())
-
-#     if request.param == 0:
-#         trainer_params.model = PaddleNormalModel_Classification(
-#             num_labels=MNISTTrainPaddleConfig.num_labels,
-#             feature_dimension=MNISTTrainPaddleConfig.feature_dimension
-#         )
-#         trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
-#         train_dataloader = DataLoader(
-#             dataset=PaddleDataset_MNIST("train"),
-#             batch_size=MNISTTrainPaddleConfig.batch_size,
-#             shuffle=True
-#         )
-#         val_dataloader = DataLoader(
-#             dataset=PaddleDataset_MNIST(evaluate_fn="test"),
-#             batch_size=MNISTTrainPaddleConfig.batch_size,
-#             shuffle=True
-#         )
-#         trainer_params.train_dataloader = train_dataloader
-#         trainer_params.evaluate_dataloaders = val_dataloader
-#         trainer_params.evaluate_every = MNISTTrainPaddleConfig.evaluate_every
-#         trainer_params.metrics = {"acc": Accuracy()}
-
-#     return trainer_params
-
-
-@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
+@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
 # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
 @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
                                         RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
 @magic_argv_env_context
 def test_trainer_paddle(
-    # model_and_optimizers: TrainerParameters,
     driver,
     device,
     callbacks,
-    n_epochs=15,
+    n_epochs=2,
 ):
     trainer_params = TrainerParameters()
 
-    trainer_params.model = PaddleNormalModel_Classification(
+    trainer_params.model = PaddleNormalModel_Classification_1(
         num_labels=MNISTTrainPaddleConfig.num_labels,
         feature_dimension=MNISTTrainPaddleConfig.feature_dimension
     )
     trainer_params.optimizers = 
Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) train_dataloader = DataLoader( - dataset=PaddleDataset_MNIST("train"), + dataset=PaddleRandomMaxDataset(6400, 10), batch_size=MNISTTrainPaddleConfig.batch_size, shuffle=True ) val_dataloader = DataLoader( - dataset=PaddleDataset_MNIST(mode="test"), + dataset=PaddleRandomMaxDataset(1000, 10), batch_size=MNISTTrainPaddleConfig.batch_size, shuffle=True ) @@ -113,39 +81,19 @@ def test_trainer_paddle( trainer_params.validate_dataloaders = val_dataloader trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every trainer_params.metrics = {"acc": Accuracy(backend="paddle")} - if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as exc: - trainer = Trainer( - model=trainer_params.model, - driver=driver, - device=device, - optimizers=trainer_params.optimizers, - train_dataloader=trainer_params.train_dataloader, - evaluate_dataloaders=trainer_params.validate_dataloaders, - evaluate_every=trainer_params.validate_every, - input_mapping=trainer_params.input_mapping, - output_mapping=trainer_params.output_mapping, - metrics=trainer_params.metrics, - - n_epochs=n_epochs, - callbacks=callbacks, - ) - assert exc.value.code == 0 - return - else: - trainer = Trainer( - model=trainer_params.model, - driver=driver, - device=device, - optimizers=trainer_params.optimizers, - train_dataloader=trainer_params.train_dataloader, - evaluate_dataloaders=trainer_params.validate_dataloaders, - evaluate_every=trainer_params.validate_every, - input_mapping=trainer_params.input_mapping, - output_mapping=trainer_params.output_mapping, - metrics=trainer_params.metrics, - - n_epochs=n_epochs, - callbacks=callbacks, - ) - trainer.run() \ No newline at end of file + trainer = Trainer( + model=trainer_params.model, + driver=driver, + device=device, + optimizers=trainer_params.optimizers, + train_dataloader=trainer_params.train_dataloader, 
+ validate_dataloaders=trainer_params.validate_dataloaders, + validate_every=trainer_params.validate_every, + input_mapping=trainer_params.input_mapping, + output_mapping=trainer_params.output_mapping, + metrics=trainer_params.metrics, + + n_epochs=n_epochs, + callbacks=callbacks, + ) + trainer.run() diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index ec5bb846..9661c015 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -224,7 +224,6 @@ class TestSetDistReproDataloder: """ def setup_method(self): self.dataset = PaddleNormalDataset(20) - self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) model = PaddleNormalModel_Classification_1(10, 32) self.driver = PaddleSingleDriver(model, device="cpu") @@ -233,55 +232,59 @@ class TestSetDistReproDataloder: 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 当dist为字符串时,此时应该返回原来的 dataloader """ - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) - assert replaced_loader is self.dataloader + assert replaced_loader is dataloader def test_set_dist_repro_dataloader_with_reproducible_true(self): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler """ - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) - assert not (replaced_loader is self.dataloader) + assert not (replaced_loader is dataloader) assert 
isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) - assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size - assert replaced_loader.drop_last == self.dataloader.drop_last + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last - # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) - assert not (replaced_loader is self.dataloader) + assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert replaced_loader.batch_sampler is dist - # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dist_sampler(self): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) dist = RandomSampler(self.dataset, shuffle=True) - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, 
reproducible=False) - assert not (replaced_loader is self.dataloader) + assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) - assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.sampler is dist - assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size - # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): """ @@ -295,11 +298,12 @@ class TestSetDistReproDataloder: replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last - # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): """ @@ -316,34 +320,52 @@ class TestSetDistReproDataloder: assert not (replaced_loader is dataloader) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert replaced_loader.batch_sampler.batch_size == 2 + assert 
replaced_loader.batch_sampler.sampler.shuffle == True - # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): """ 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ # 迭代两个 batch - # 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 + num_consumed_batches = 2 already_seen_idx = set() - for idx, batch in replaced_loader: - already_seen_idx.update(batch) - if idx >= 1: + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: break + already_seen_idx.update(batch) if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() - print(sampler_states["data_idx"]) + + # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + + import time + time.sleep(5) # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + if num_consumed_samples_array is not None: + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size replaced_loader.batch_sampler.load_state_dict(sampler_states) else: + batch_size = replaced_loader.batch_sampler.batch_size + if num_consumed_samples_array is not None: + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) + replaced_loader.batch_sampler.sampler.set_epoch(0) for idx, batch in enumerate(replaced_loader): 
left_idxes.update(batch) diff --git a/tests/core/utils/test_utils.py b/tests/core/utils/test_utils.py index a7aeffb1..556f85ff 100644 --- a/tests/core/utils/test_utils.py +++ b/tests/core/utils/test_utils.py @@ -181,6 +181,7 @@ class TestCheckNumberOfParameters: def test_get_fun_msg(): + # 测试运行 def demo(x): pass diff --git a/tests/helpers/datasets/normal_data.py b/tests/helpers/datasets/normal_data.py index ba1af370..714ec676 100644 --- a/tests/helpers/datasets/normal_data.py +++ b/tests/helpers/datasets/normal_data.py @@ -1,3 +1,6 @@ +import numpy as np + + class NormalIterator: def __init__(self, num_of_data=1000): self._num_of_data = num_of_data @@ -15,4 +18,15 @@ class NormalIterator: return self._data def __len__(self): - return self._num_of_data \ No newline at end of file + return self._num_of_data + + +class RandomDataset: + def __init__(self, num_data=10): + self.data = np.random.rand(num_data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + return self.data[item] \ No newline at end of file