@@ -390,4 +390,23 @@ class HasMonitorCallback(Callback): | |||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||
(not self.larger_better and monitor_value1 < monitor_value2): | |||
better = True | |||
return better | |||
return better | |||
@property | |||
def monitor_name(self): | |||
""" | |||
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||
:return: | |||
""" | |||
if callable(self.monitor): | |||
try: | |||
monitor_name = self.monitor.__qualname__ | |||
except: | |||
monitor_name = self.monitor.__name__ | |||
elif self.monitor is None: | |||
return None | |||
else: | |||
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 | |||
monitor_name = str(self.monitor) | |||
return monitor_name |
@@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
class CheckpointCallback(HasMonitorCallback): | |||
def __init__( | |||
self, | |||
monitor, | |||
monitor:Optional[Union[str, Callable]]=None, | |||
save_folder: Optional[Union[str, Path]] = None, | |||
save_every_n_epochs: Optional[int] = None, | |||
save_every_n_batches: Optional[int] = None, | |||
save_last: bool = True, | |||
save_last: bool = False, | |||
save_topk: Optional[int] = None, | |||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||
larger_better: bool = True, | |||
@@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback): | |||
model_save_fn: Optional[Callable] = None, | |||
**kwargs, | |||
): | |||
""" | |||
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
:param save_topk: 保存 monitor 结果 topK 个。 | |||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||
:param larger_better: monitor 的值是否时越大越好。 | |||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param kwargs: | |||
""" | |||
super().__init__(monitor=monitor, larger_better=larger_better, | |||
must_have_monitor=save_topk is not None) | |||
if save_folder is None: | |||
logger.warning( | |||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||
save_folder = Path.cwd() | |||
save_folder = Path(save_folder) | |||
if not save_folder.exists(): | |||
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | |||
elif save_folder.is_file(): | |||
@@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
else: | |||
save_on_exception = [] | |||
self.save_folder = Path(save_folder) | |||
self.save_folder = save_folder | |||
self.save_every_n_epochs = save_every_n_epochs | |||
self.save_every_n_batches = save_every_n_batches | |||
self.save_last = save_last | |||
@@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||
synchronize_mkdir(self.timestamp_path) | |||
# 该 folder 只在保存真的要发生的时候再创建。 | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if self.save_topk is not None: | |||
@@ -98,8 +117,6 @@ class CheckpointCallback(HasMonitorCallback): | |||
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||
def on_validate_end(self, trainer, results): | |||
if len(results) == 0: | |||
return | |||
self._save_topk(trainer, results) | |||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | |||
@@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback): | |||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||
states['_topk_model'] = deepcopy(self._topk_model) | |||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||
states['_real_monitor'] = self._real_monitor | |||
if isinstance(self._real_monitor, str): | |||
states['_real_monitor'] = self._real_monitor | |||
return states | |||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
timestamp_path = states['timestamp_path'] | |||
if not os.path.exists(timestamp_path): | |||
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " | |||
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " | |||
f" {self.timestamp_path.absolute()}.") | |||
else: | |||
logger.info(f"Resume to save in path: {timestamp_path}.") | |||
logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||
self.timestamp_path = Path(timestamp_path) | |||
_topk_model = states['_topk_model'] | |||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||
@@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback): | |||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||
f"as {save_topk}." | |||
self._topk_model.update(self._topk_model) | |||
self._real_monitor = states["real_monitor"] | |||
self._real_monitor = states["_real_monitor"] | |||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||
""" | |||
@@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||
返回一个 float 值作为 monitor 的结果。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
@@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
""" | |||
@property | |||
def save_fn_name(self): | |||
""" | |||
调用 Trainer 中的哪个函数。 | |||
:return: | |||
""" | |||
return 'save_model' | |||
@property | |||
@@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
:return: | |||
""" | |||
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
@property | |||
def folder_prefix(self): | |||
@@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||
返回一个 float 值作为 monitor 的结果。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
@@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
""" | |||
@property | |||
def save_fn_name(self): | |||
""" | |||
调用 Trainer 中的哪个函数。 | |||
:return: | |||
""" | |||
return 'save' | |||
@property | |||
@@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
:return: | |||
""" | |||
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
@property | |||
def folder_prefix(self): | |||
@@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||
""" | |||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: monitor 的值是否是越大越好。 | |||
:param patience: 多少次 validate 不没有提升就停止。 | |||
""" | |||
@@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback): | |||
states = { | |||
'patience': self.patience, | |||
'wait': self.wait, | |||
'monitor': self.monitor, | |||
'monitor_value': self.monitor_value | |||
} | |||
if not callable(self._real_monitor): | |||
states['_real_monitor'] = self._real_monitor | |||
return states | |||
def on_load_checkpoint(self, trainer, states): | |||
self.patience = states['patience'] | |||
self.wait = states['wait'] | |||
self.monitor = states['monitor'] | |||
self.monitor_value = float(states['monitor_value']) | |||
if '_real_monitor' in states: | |||
self._real_monitor = states['_real_monitor'] | |||
@property | |||
def callback_name(self): | |||
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' | |||
return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}' | |||
@@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
""" | |||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | |||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: 该 metric 值是否是越大越好。 | |||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||
@@ -44,10 +44,11 @@ class RichCallback(ProgressCallback): | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||
也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: 是否是monitor的结果越大越好。 | |||
:param format_json: 是否format json再打印 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||
:param format_json: 是否格式化 json 再打印 | |||
""" | |||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||
self.print_every = print_every | |||
@@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback): | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( | |||
字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: 是否是monitor的结果越大越好。 | |||
:param format_json: 是否format json再打印 | |||
""" | |||
@@ -36,7 +36,7 @@ class Evaluator: | |||
model, | |||
dataloaders, | |||
metrics: Optional[Union[Dict, Metric]] = None, | |||
driver: Union[str, Driver] = 'single', | |||
driver: Union[str, Driver] = 'torch', | |||
device: Optional[Union[int, List[int], str]] = None, | |||
batch_step_fn: Optional[callable] = None, | |||
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
@@ -49,8 +49,8 @@ class Evaluator: | |||
): | |||
""" | |||
:param dataloaders: | |||
:param model: | |||
:param dataloaders: | |||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | |||
metric ,torchmetrics,allennlpmetrics等。 | |||
:param driver: 使用 driver 。 | |||
@@ -120,7 +120,8 @@ class Evaluator: | |||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | |||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||
self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||
self._evaluate_step, self._evaluate_step_signature_fn = \ | |||
self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||
self.evaluate_fn = evaluate_fn | |||
self.dataloaders = {} | |||
@@ -134,8 +135,6 @@ class Evaluator: | |||
self.driver.barrier() | |||
self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) | |||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
""" | |||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||
@@ -20,7 +20,7 @@ class TrainBatchLoop(Loop): | |||
else lambda *args, **kwargs: None | |||
dataloader = iter(dataloader) | |||
indices = None | |||
while True: | |||
while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: | |||
try: | |||
trainer.on_fetch_data_begin() | |||
batch = next(dataloader) | |||
@@ -30,10 +30,8 @@ class TrainBatchLoop(Loop): | |||
batch = trainer.move_data_to_device(batch) | |||
except StopIteration: | |||
break | |||
except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||
break | |||
except BaseException as e: | |||
if indices: | |||
if indices and not isinstance(e, EarlyStopException): | |||
logger.debug(f"The following exception happens when running on samples: {indices}") | |||
raise e | |||
@@ -174,9 +174,8 @@ class Trainer(TrainerEventTrigger): | |||
optimizers=optimizers, | |||
device=device, | |||
n_epochs=n_epochs, | |||
validate_dataloaders=evaluate_dataloaders, | |||
batch_step_fn=batch_step_fn, | |||
validate_batch_step_fn=evaluate_batch_step_fn, | |||
z=evaluate_batch_step_fn, | |||
evaluate_fn=evaluate_fn, | |||
callbacks=callbacks, | |||
metrics=metrics, | |||
@@ -264,7 +263,6 @@ class Trainer(TrainerEventTrigger): | |||
self.on_after_trainer_initialized(self.driver) | |||
self.driver.barrier() | |||
self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | |||
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | |||
@@ -310,7 +308,7 @@ class Trainer(TrainerEventTrigger): | |||
self.num_batches_per_epoch = len(self.dataloader) | |||
self.total_batches = self.num_batches_per_epoch * self.n_epochs | |||
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | |||
self.on_train_begin() | |||
self.driver.barrier() | |||
self.driver.zero_grad(self.set_grad_to_none) | |||
@@ -637,6 +635,8 @@ class Trainer(TrainerEventTrigger): | |||
:param folder: 保存断点重训 states 的文件地址; | |||
:param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | |||
只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | |||
:param only_state_dict: 保存的 model 是否只包含了权重。 | |||
:param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。 | |||
""" | |||
self.driver.barrier() | |||
if isinstance(folder, str): | |||
@@ -675,8 +675,6 @@ class Trainer(TrainerEventTrigger): | |||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | |||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | |||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||
self.batch_idx_in_epoch | |||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
@@ -65,10 +65,10 @@ class TrainerState: | |||
""" | |||
n_epochs: Optional[int] = None # 无论如何重新算 | |||
cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; | |||
global_forward_batches: Optional[int] = None # 断点重训 | |||
cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0; | |||
global_forward_batches: Optional[int] = 0 # 断点重训 | |||
batch_idx_in_epoch: Optional[int] = None # 断点重训 | |||
batch_idx_in_epoch: Optional[int] = 0 # 断点重训 | |||
num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | |||
@@ -86,7 +86,7 @@ class Driver(ABC): | |||
函数; | |||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
:param fn: 由 Trainer 传入的用于网络前向传播一次的函数; | |||
:param fn: 调用该函数进行一次计算。 | |||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
@@ -126,17 +126,6 @@ class Driver(ABC): | |||
def model(self, model): | |||
self._model = model | |||
@staticmethod | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
r""" | |||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 | |||
行为是不相同的; | |||
:param dataloader: 需要检测的输入的 `dataloader`; | |||
:param dataloader_name: | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") | |||
@property | |||
def optimizers(self) -> List: | |||
r""" | |||
@@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver): | |||
if hasattr(model, fn): | |||
fn = getattr(model, fn) | |||
if not callable(fn): | |||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") | |||
return fn, None | |||
elif fn in {"train_step", "evaluate_step"}: | |||
return model, model.forward | |||
@@ -199,6 +199,7 @@ class TorchDriver(Driver): | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
states['sampler_states'] = sampler_states | |||
else: | |||
raise RuntimeError( | |||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
@@ -181,6 +181,7 @@ class TestCheckNumberOfParameters: | |||
def test_get_fun_msg(): | |||
# 测试运行 | |||
def demo(x): | |||
pass | |||
@@ -1,3 +1,6 @@ | |||
import numpy as np | |||
class NormalIterator: | |||
def __init__(self, num_of_data=1000): | |||
self._num_of_data = num_of_data | |||
@@ -15,4 +18,15 @@ class NormalIterator: | |||
return self._data | |||
def __len__(self): | |||
return self._num_of_data | |||
return self._num_of_data | |||
class RandomDataset: | |||
def __init__(self, num_data=10): | |||
self.data = np.random.rand(num_data) | |||
def __len__(self): | |||
return len(self.data) | |||
def __getitem__(self, item): | |||
return self.data[item] |