Browse Source

修复若干bug

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
a4b2e0fac5
14 changed files with 121 additions and 65 deletions
  1. +20
    -1
      fastNLP/core/callbacks/callback.py
  2. +49
    -19
      fastNLP/core/callbacks/checkpoint_callback.py
  3. +9
    -5
      fastNLP/core/callbacks/early_stop_callback.py
  4. +3
    -2
      fastNLP/core/callbacks/load_best_model_callback.py
  5. +8
    -6
      fastNLP/core/callbacks/progress_callback.py
  6. +4
    -5
      fastNLP/core/controllers/evaluator.py
  7. +2
    -4
      fastNLP/core/controllers/loops/train_batch_loop.py
  8. +4
    -6
      fastNLP/core/controllers/trainer.py
  9. +3
    -3
      fastNLP/core/controllers/utils/state.py
  10. +1
    -12
      fastNLP/core/drivers/driver.py
  11. +1
    -1
      fastNLP/core/drivers/torch_driver/ddp.py
  12. +1
    -0
      fastNLP/core/drivers/torch_driver/torch_driver.py
  13. +1
    -0
      tests/core/utils/test_utils.py
  14. +15
    -1
      tests/helpers/datasets/normal_data.py

+ 20
- 1
fastNLP/core/callbacks/callback.py View File

@@ -390,4 +390,23 @@ class HasMonitorCallback(Callback):
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better
return better

@property
def monitor_name(self):
"""
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。

:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了
monitor_name = str(self.monitor)
return monitor_name

+ 49
- 19
fastNLP/core/callbacks/checkpoint_callback.py View File

@@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
class CheckpointCallback(HasMonitorCallback):
def __init__(
self,
monitor,
monitor:Optional[Union[str, Callable]]=None,
save_folder: Optional[Union[str, Path]] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_batches: Optional[int] = None,
save_last: bool = True,
save_last: bool = False,
save_topk: Optional[int] = None,
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
larger_better: bool = True,
@@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback):
model_save_fn: Optional[Callable] = None,
**kwargs,
):
"""
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param save_every_n_epochs: 多少个 epoch 保存一次。
:param save_every_n_batches: 多少个 batch 保存一次。
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
:param save_topk: 保存 monitor 结果 topK 个。
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。
:param larger_better: monitor 的值是否时越大越好。
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param kwargs:
"""
super().__init__(monitor=monitor, larger_better=larger_better,
must_have_monitor=save_topk is not None)
if save_folder is None:
logger.warning(
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
save_folder = Path.cwd()
save_folder = Path(save_folder)
if not save_folder.exists():
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!")
elif save_folder.is_file():
@@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback):
else:
save_on_exception = []

self.save_folder = Path(save_folder)
self.save_folder = save_folder
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_batches = save_every_n_batches
self.save_last = save_last
@@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback):
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候,
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的;
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行;
synchronize_mkdir(self.timestamp_path)
# 该 folder 只在保存真的要发生的时候再创建。

def on_after_trainer_initialized(self, trainer, driver):
if self.save_topk is not None:
@@ -98,8 +117,6 @@ class CheckpointCallback(HasMonitorCallback):
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.")

def on_validate_end(self, trainer, results):
if len(results) == 0:
return
self._save_topk(trainer, results)

def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
@@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback):
states['timestamp_path'] = str(self.timestamp_path.absolute())
states['_topk_model'] = deepcopy(self._topk_model)
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
states['_real_monitor'] = self._real_monitor
if isinstance(self._real_monitor, str):
states['_real_monitor'] = self._real_monitor
return states

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to save in path: {timestamp_path}.")
logger.info(f"Resume to checkpoint in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)
_topk_model = states['_topk_model']
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
@@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback):
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
f"as {save_topk}."
self._topk_model.update(self._topk_model)
self._real_monitor = states["real_monitor"]

self._real_monitor = states["_real_monitor"]

def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict):
"""
@@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback):
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。

:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),
返回一个 float 值作为 monitor 的结果。
:param monitor: 监控的 metric 。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param save_every_n_epochs: 多少个 epoch 保存一次。
@@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback):
"""
@property
def save_fn_name(self):
"""
调用 Trainer 中的哪个函数。

:return:
"""
return 'save_model'

@property
@@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback):
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态;
:return:
"""
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
@@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback):
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。

:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),
返回一个 float 值作为 monitor 的结果。
:param monitor: 监控的 metric 。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param save_every_n_epochs: 多少个 epoch 保存一次。
@@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback):
"""
@property
def save_fn_name(self):
"""
调用 Trainer 中的哪个函数。

:return:
"""
return 'save'

@property
@@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态;
:return:
"""
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):


+ 9
- 5
fastNLP/core/callbacks/early_stop_callback.py View File

@@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback):
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
"""

:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: monitor 的值是否是越大越好。
:param patience: 多少次 validate 不没有提升就停止。
"""
@@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback):
states = {
'patience': self.patience,
'wait': self.wait,
'monitor': self.monitor,
'monitor_value': self.monitor_value
}
if not callable(self._real_monitor):
states['_real_monitor'] = self._real_monitor
return states

def on_load_checkpoint(self, trainer, states):
self.patience = states['patience']
self.wait = states['wait']
self.monitor = states['monitor']
self.monitor_value = float(states['monitor_value'])
if '_real_monitor' in states:
self._real_monitor = states['_real_monitor']

@property
def callback_name(self):
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}'
return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}'


+ 3
- 2
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback):
"""
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。

:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: 该 metric 值是否是越大越好。
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。


+ 8
- 6
fastNLP/core/callbacks/progress_callback.py View File

@@ -44,10 +44,11 @@ class RichCallback(ProgressCallback):

:param print_every: 多少个 batch 更新一次显示。
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。
也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: 是否是monitor的结果越大越好。
:param format_json: 是否format json再打印
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: 是否是 monitor 的结果越大越好。
:param format_json: 是否格式化 json 再打印
"""
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
self.print_every = print_every
@@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback):

:param print_every: 多少个 batch 更新一次显示。
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果(
字典类型),返回一个 float 值作为 monitor 的结果。
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: 是否是monitor的结果越大越好。
:param format_json: 是否format json再打印
"""


+ 4
- 5
fastNLP/core/controllers/evaluator.py View File

@@ -36,7 +36,7 @@ class Evaluator:
model,
dataloaders,
metrics: Optional[Union[Dict, Metric]] = None,
driver: Union[str, Driver] = 'single',
driver: Union[str, Driver] = 'torch',
device: Optional[Union[int, List[int], str]] = None,
batch_step_fn: Optional[callable] = None,
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable
@@ -49,8 +49,8 @@ class Evaluator:
):
"""

:param dataloaders:
:param model:
:param dataloaders:
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的
metric ,torchmetrics,allennlpmetrics等。
:param driver: 使用 driver 。
@@ -120,7 +120,8 @@ class Evaluator:

if evaluate_fn is not None and not isinstance(evaluate_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
self._evaluate_step, self._evaluate_step_signature_fn = \
self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
self.evaluate_fn = evaluate_fn

self.dataloaders = {}
@@ -134,8 +135,6 @@ class Evaluator:

self.driver.barrier()

self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False)

def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
"""
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。


+ 2
- 4
fastNLP/core/controllers/loops/train_batch_loop.py View File

@@ -20,7 +20,7 @@ class TrainBatchLoop(Loop):
else lambda *args, **kwargs: None
dataloader = iter(dataloader)
indices = None
while True:
while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
try:
trainer.on_fetch_data_begin()
batch = next(dataloader)
@@ -30,10 +30,8 @@ class TrainBatchLoop(Loop):
batch = trainer.move_data_to_device(batch)
except StopIteration:
break
except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception
break
except BaseException as e:
if indices:
if indices and not isinstance(e, EarlyStopException):
logger.debug(f"The following exception happens when running on samples: {indices}")
raise e



+ 4
- 6
fastNLP/core/controllers/trainer.py View File

@@ -174,9 +174,8 @@ class Trainer(TrainerEventTrigger):
optimizers=optimizers,
device=device,
n_epochs=n_epochs,
validate_dataloaders=evaluate_dataloaders,
batch_step_fn=batch_step_fn,
validate_batch_step_fn=evaluate_batch_step_fn,
z=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn,
callbacks=callbacks,
metrics=metrics,
@@ -264,7 +263,6 @@ class Trainer(TrainerEventTrigger):
self.on_after_trainer_initialized(self.driver)

self.driver.barrier()
self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True)

def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
@@ -310,7 +308,7 @@ class Trainer(TrainerEventTrigger):

self.num_batches_per_epoch = len(self.dataloader)
self.total_batches = self.num_batches_per_epoch * self.n_epochs
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch
self.on_train_begin()
self.driver.barrier()
self.driver.zero_grad(self.set_grad_to_none)
@@ -637,6 +635,8 @@ class Trainer(TrainerEventTrigger):
:param folder: 保存断点重训 states 的文件地址;
:param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们
只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置;
:param only_state_dict: 保存的 model 是否只包含了权重。
:param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。
"""
self.driver.barrier()
if isinstance(folder, str):
@@ -675,8 +675,6 @@ class Trainer(TrainerEventTrigger):
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
self.batch_idx_in_epoch
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch



+ 3
- 3
fastNLP/core/controllers/utils/state.py View File

@@ -65,10 +65,10 @@ class TrainerState:
"""
n_epochs: Optional[int] = None # 无论如何重新算

cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0;
global_forward_batches: Optional[int] = None # 断点重训
cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0;
global_forward_batches: Optional[int] = 0 # 断点重训

batch_idx_in_epoch: Optional[int] = None # 断点重训
batch_idx_in_epoch: Optional[int] = 0 # 断点重训

num_batches_per_epoch: Optional[int] = None # 无论如何重新算



+ 1
- 12
fastNLP/core/drivers/driver.py View File

@@ -86,7 +86,7 @@ class Driver(ABC):
函数;

:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
:param fn: 由 Trainer 传入的用于网络前向传播一次的函数;
:param fn: 调用该函数进行一次计算。
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
@@ -126,17 +126,6 @@ class Driver(ABC):
def model(self, model):
self._model = model

@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
r"""
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的
行为是不相同的;

:param dataloader: 需要检测的输入的 `dataloader`;
:param dataloader_name:
"""
raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.")

@property
def optimizers(self) -> List:
r"""


+ 1
- 1
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver):
if hasattr(model, fn):
fn = getattr(model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return model, model.forward


+ 1
- 0
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -199,6 +199,7 @@ class TorchDriver(Driver):
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
states['sampler_states'] = sampler_states
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')


+ 1
- 0
tests/core/utils/test_utils.py View File

@@ -181,6 +181,7 @@ class TestCheckNumberOfParameters:


def test_get_fun_msg():
# 测试运行
def demo(x):
pass


+ 15
- 1
tests/helpers/datasets/normal_data.py View File

@@ -1,3 +1,6 @@
import numpy as np


class NormalIterator:
def __init__(self, num_of_data=1000):
self._num_of_data = num_of_data
@@ -15,4 +18,15 @@ class NormalIterator:
return self._data

def __len__(self):
return self._num_of_data
return self._num_of_data


class RandomDataset:
def __init__(self, num_data=10):
self.data = np.random.rand(num_data)

def __len__(self):
return len(self.data)

def __getitem__(self, item):
return self.data[item]

Loading…
Cancel
Save