diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index d07382e4..02b56cd7 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -14,7 +14,7 @@ __all__ = [ 'MoreEvaluateCallback', "TorchWarmupCallback", "TorchGradClipCallback", - "MonitorUtility", + "ResultsMonitor", 'HasMonitorCallback', # collators diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index 6f859183..9ba0d227 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -16,7 +16,7 @@ __all__ = [ "TorchWarmupCallback", "TorchGradClipCallback", - "MonitorUtility", + "ResultsMonitor", 'HasMonitorCallback' ] @@ -31,5 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback from .early_stop_callback import EarlyStopCallback from .torch_callbacks import * from .more_evaluate_callback import MoreEvaluateCallback -from .has_monitor_callback import MonitorUtility, HasMonitorCallback +from .has_monitor_callback import ResultsMonitor, HasMonitorCallback diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 35ca3f53..eabc489b 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -57,7 +57,7 @@ def prepare_callbacks(callbacks, progress_bar): if has_no_progress and progress_bar is not None: callback = choose_progress_callback(progress_bar) if callback is not None: - _callbacks.append(callback) + _callbacks = [callback] + _callbacks # 放在最前面,方便分割不同 epoch has_no_progress = False elif has_no_progress is False and progress_bar not in ('auto', None): logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.") @@ -146,11 +146,13 @@ class CallbackManager: r""" 用于断点重训的 callback 的保存函数; 该函数主要涉及两个方面: - 1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 - 断点重训应当保存的状态; - 2. 每一个具体的 callback 函数的 filter 的状态; - :return: 一个包含上述内容的字典:: + 1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 + 断点重训应当保存的状态; + 2. 每一个具体的 callback 函数的 filter 的状态; + + :return: 一个包含上述内容的字典: + .. code-block:: { "callback_name_1": { @@ -158,6 +160,7 @@ class CallbackManager: "filter_states": {"on_train_begin": filter1.state_dict(), ...} } } + """ states = {} diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index 8e5eb0aa..2d1affd2 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -1,7 +1,7 @@ __all__ = [ 'HasMonitorCallback', 'ExecuteOnceBetterMonitor', - 'MonitorUtility' + 'ResultsMonitor' ] from typing import Dict, Union, Any @@ -29,12 +29,16 @@ class CanItemDataType(ABC): return NotImplemented -class MonitorUtility: - """ - 计算 monitor 的相关函数 +class ResultsMonitor: + def __init__(self, monitor:Union[Callback, str], larger_better:bool=True): + """ + 可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 - """ - def __init__(self, monitor, larger_better): + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 + :param larger_better: monitor 是否时越大越好 + """ self.set_monitor(monitor, larger_better) def set_monitor(self, monitor, larger_better): @@ -53,7 +57,7 @@ class MonitorUtility: def itemize_results(self, results): """ - 将结果中有 .item() 方法的都调用一下,使得可以结果可以保存 + 将结果中有 .item() 方法的都调用一下,使得 tensor 类型的数据转为 python 内置类型。 :param results: :return: @@ -161,7 +165,7 @@ class MonitorUtility: return monitor_name -class HasMonitorCallback(MonitorUtility, Callback): +class HasMonitorCallback(ResultsMonitor, Callback): def __init__(self, monitor, larger_better, must_have_monitor=False): """ 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index 1f34881c..33415b7a 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -39,7 +39,7 @@ class MoreEvaluateCallback(HasMonitorCallback): 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor - 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 + 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 :param watch_monitor_larger_better: watch_monitor 是否越大越好。 :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index cf6881d7..09843511 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -12,7 +12,7 @@ from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.envs import rank_zero_call from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME -from .has_monitor_callback import MonitorUtility +from .has_monitor_callback import ResultsMonitor class Saver: @@ -170,7 +170,7 @@ class TopkQueue: return self.topk != 0 -class TopkSaver(MonitorUtility, Saver): +class TopkSaver(ResultsMonitor, Saver): def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, **kwargs): @@ -196,7 +196,7 @@ class TopkSaver(MonitorUtility, Saver): fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 :param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 """ - MonitorUtility.__init__(self, monitor, larger_better) + ResultsMonitor.__init__(self, monitor, larger_better) Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs) if monitor is not None and topk == 0: diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py index 81b4ce6e..cc0e1e98 100644 --- a/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py +++ b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py @@ -10,13 +10,13 @@ class TorchGradClipCallback(Callback): 在每次 optimizer update 之前将 parameter 进行 clip :param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 - :param str clip_type: 支持'norm', 'value'两种:: + :param str clip_type: 支持'norm', 'value'两种: - 1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] + 1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] + 2. 'value', 将gradient限制在[-clip_value, clip_value], + 小于-clip_value的gradient被赋值为-clip_value; + 大于clip_value的gradient被赋值为clip_value. - 2 'value', 将gradient限制在[-clip_value, clip_value], - 小于-clip_value的gradient被赋值为-clip_value; - 大于clip_value的gradient被赋值为clip_value. :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 """ diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 5c7be44b..db48011b 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -118,6 +118,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> elif backend == 'numpy': return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'torch': + # 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'paddle': return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 5432b17a..7e91ec42 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -132,6 +132,9 @@ class PaddleTensorPadder(Padder): try: if not isinstance(batch_field[0], paddle.Tensor): batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] + else: + if dtype is None: + dtype = batch_field[0].dtype except AttributeError: raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index d6d07dcd..b67aeff8 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -118,6 +118,8 @@ class TorchTensorPadder(Padder): batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] else: device = batch_field[0].device + if dtype is None: + dtype = batch_field[0].dtype except AttributeError: raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 70c7fbd0..47301955 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -8,10 +8,10 @@ __all__ = [ ] from fastNLP.core.drivers import Driver -from fastNLP.core.drivers.utils import choose_driver +from ..drivers.choose_driver import choose_driver from .loops import Loop, EvaluateBatchLoop from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ - match_and_substitute_params, f_rich_progress + match_and_substitute_params, f_rich_progress, flat_nest_dict from fastNLP.core.metrics import Metric from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader @@ -51,23 +51,30 @@ class Evaluator: 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; :param fp16: 是否使用 fp16 。 :param verbose: 是否打印 evaluate 的结果。 - :param kwargs: - bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout - 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 - 该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 - TODO 还没完成。 - Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 - tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, - 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 - 不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, - metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 - use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 - 分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 - output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: - ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 - log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; - progress_bar: evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 - 到当前terminal为交互型则使用 rich,否则使用 raw。 + :param \**kwargs: + See below + :kwargs: + * *model_use_eval_mode* (``bool``) -- + 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 + dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 + 该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 + TODO 还没完成。 + * *auto_tensor_conversion_for_metric* (``Union[bool]``) -- + 是否自动将输出中的 tensor 适配到 metrics 支持的。例如 model 输出是 + paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,当 auto_tensor_conversion_for_metric 为True时,fastNLP 将 + 自动将输出中 paddle 的 tensor (其它非 tensor 的参数不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的 + 输出 tensor 类型通过 driver 来决定,metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换, + 请使用 input_mapping、output_mapping 参数进行。 + * *use_dist_sampler* -- + 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 + 分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 + * *output_from_new_proc* -- + 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: + ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 + log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; + * *progress_bar* -- + evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 + 到当前terminal为交互型则使用 rich,否则使用 raw。 """ self.model = model @@ -155,19 +162,21 @@ class Evaluator: self.cur_dataloader_name = dataloader_name results = self.evaluate_batch_loop.run(self, dataloader) self.remove_progress_bar(dataloader_name) - metric_results.update(results) + metric_results[dataloader_name] = results self.reset() self.driver.barrier() except BaseException as e: raise e finally: self.finally_progress_bar() + if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 + metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) + if self.verbose: + if self.progress_bar == 'rich': + f_rich_progress.print(metric_results) + else: + logger.info(metric_results) self.driver.set_model_mode(mode='train') - if self.verbose: - if self.progress_bar == 'rich': - f_rich_progress.print(metric_results) - else: - logger.info(metric_results) return metric_results @@ -244,14 +253,13 @@ class Evaluator: """ self.metrics_wrapper.update(batch, outputs) - def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: + def get_metric(self) -> Dict: """ - 获取当前dataloader的metric结果 + 调用所有 metric 的 get_metric 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 - :param str dataloader_name: 当前dataloader的名字 :return: """ - return self.metrics_wrapper.get_metric(dataloader_name=dataloader_name, separator=self.separator) + return self.metrics_wrapper.get_metric() @property def metrics_wrapper(self): @@ -359,15 +367,12 @@ class _MetricsWrapper: elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric): metric.reset() - def get_metric(self, dataloader_name: str, separator: str) -> Dict: + def get_metric(self) -> Dict: """ - 将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是 - indicator_name{separator}metric_name{separator}dataloader_name - 例如: f1#F1PreRec#dev + 调用各个 metric 得到 metric 的结果。并使用 {'metric_name1': metric_results, 'metric_name2': metric_results} 的形式 + 返回。 - :param dataloader_name: 当前metric对应的dataloader的名字。若为空,则不显示在最终的key上面。 - :param separator: 用于间隔不同称呼。 - :return: 返回一个一级结构的字典,其中 key 为区别一个 metric 的名字,value 为该 metric 的值; + :return: """ results = {} for metric_name, metric in zip(self._metric_names, self._metrics): @@ -377,37 +382,9 @@ class _MetricsWrapper: _results = metric.get_metric(reset=False) elif _is_torchmetrics_metric(metric): _results = metric.compute() - # 我们规定了 evaluator 中的 metrics 的输入只能是一个 dict,这样如果 metric 是一个 torchmetrics 时,如果 evaluator - # 没有传入 func_post_proc,那么我们就自动使用该 metric 的 metric name 当做其的 indicator name 将其自动转换成一个字典; elif _is_paddle_metric(metric): _results = metric.accumulate() - if not isinstance(_results, Dict): - name = _get_metric_res_name(dataloader_name, metric_name, '', separator) - results[name] = _results else: - for indicator_name, value in _results.items(): - name = _get_metric_res_name(dataloader_name, metric_name, indicator_name, separator) - results[name] = value - + raise RuntimeError(f"Not support `{type(metric)}` for now.") + results[metric_name] = _results return results - - -def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indicator_name: str, separator='#') -> str: - """ - - :param dataloader_name: dataloder的名字 - :param metric_name: metric的名字 - :param indicator_name: metric中的各项metric名称,例如f, precision, recall - :param separator: 用以间隔不同对象的间隔符 - :return: - """ - names = [] - if indicator_name: - names.append(indicator_name) - if metric_name: - names.append(metric_name) - if dataloader_name: - names.append(dataloader_name) - if len(names) == 0: - raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.") - return separator.join(names) diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py index 2d8f07d1..0bf66fda 100644 --- a/fastNLP/core/controllers/loops/evaluate_batch_loop.py +++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py @@ -40,8 +40,8 @@ class EvaluateBatchLoop(Loop): self.batch_step_fn(evaluator, batch) batch_idx += 1 evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name) - # 获取metric结果。返回的dict内容示例为{'f1#F1Metric#dl1': 0.93, 'pre#F1Metric#dl1': 0.95, ...} - results = evaluator.get_dataloader_metric(dataloader_name=evaluator.cur_dataloader_name) + # 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} + results = evaluator.get_metric() return results @staticmethod diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 54ce5f28..9c8cd874 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -23,7 +23,7 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_manager import prepare_callbacks from fastNLP.core.callbacks.callback_event import Event from fastNLP.core.drivers import Driver -from fastNLP.core.drivers.utils import choose_driver +from ..drivers.choose_driver import choose_driver from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext from fastNLP.core.utils.utils import _check_valid_parameters_number from fastNLP.envs import rank_zero_call @@ -67,20 +67,21 @@ class Trainer(TrainerEventTrigger): 要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; :param model: 训练所需要的模型,目前支持 pytorch; - :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle - 等国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 + :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle 等 + 国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; :param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你 - 可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 - 可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 - 自己构造 DDP 的多进程场景); + 可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 + 可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 + 自己构造 DDP 的多进程场景); device 的可选输入如下所示: 1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; 2. torch.device:将模型装载到torch.device上; 3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`; 4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; 5. None: 为None则不对模型进行任何处理; + :param n_epochs: 训练总共的 epoch 的数量,默认为 20; :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 为 None; @@ -121,26 +122,27 @@ class Trainer(TrainerEventTrigger): 如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 :param larger_better: monitor 的值是否是越大越好。 :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; - :param kwargs: 一些其它的可能需要的参数; - torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; - data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; - 注意如果 model_device 为 None,那么 data_device 不会起作用; - torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入 - {'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。 - set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; - use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch + :param kwargs: 一些其它的可能需要的参数,见下方的说明 + :kwargs: + * *torch_non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; + * *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; + 注意如果 model_device 为 None,那么 data_device 不会起作用; + * *torch_ddp_kwargs* -- 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入 + {'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。 + * *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; + * *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 - evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; - output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: + * *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; + * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; - progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, + * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 - train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 - train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 - evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 - evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 + * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 + * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 + * *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 """ self.model = model self.marker = marker @@ -290,14 +292,14 @@ class Trainer(TrainerEventTrigger): catch_KeyboardInterrupt=None): """ 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint - 去保存断点重训的文件; + 去保存断点重训的文件; :param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 :param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止,-1 为根据 dataloader 有多少个 batch 决定。 :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误。为 0 表示不检测。 :param resume_from: 从哪个路径下恢复 trainer 的状态 :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 - 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) + 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) :return: """ @@ -417,39 +419,42 @@ class Trainer(TrainerEventTrigger): def on(cls, event: Event, marker: Optional[str] = None): r""" 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; - 支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如 - Trainer.__init__(): - on_after_trainer_initialized(trainer, driver) - Trainer.run(): - if num_eval_sanity_batch>0: - on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch - on_sanity_check_end(trainer, sanity_check_res) - try: - on_train_begin(trainer) - while cur_epoch_idx < n_epochs: - on_train_epoch_begin(trainer) - while batch_idx_in_epoch<=num_batches_per_epoch: - on_fetch_data_begin(trainer) - batch = next(dataloader) - on_fetch_data_end(trainer) - on_train_batch_begin(trainer, batch, indices) - on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 - on_after_backward(trainer) - on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 - on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 - on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 - on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 - on_train_batch_end(trainer) - on_train_epoch_end(trainer) - except BaseException: - self.on_exception(trainer, exception) - finally: - on_train_end(trainer) + 支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如:: + + Trainer.__init__(): + on_after_trainer_initialized(trainer, driver) + Trainer.run(): + if num_eval_sanity_batch>0: + on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch + on_sanity_check_end(trainer, sanity_check_res) + try: + on_train_begin(trainer) + while cur_epoch_idx < n_epochs: + on_train_epoch_begin(trainer) + while batch_idx_in_epoch<=num_batches_per_epoch: + on_fetch_data_begin(trainer) + batch = next(dataloader) + on_fetch_data_end(trainer) + on_train_batch_begin(trainer, batch, indices) + on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 + on_after_backward(trainer) + on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_train_batch_end(trainer) + on_train_epoch_end(trainer) + except BaseException: + self.on_exception(trainer, exception) + finally: + on_train_end(trainer) + 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ - on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 - 特定的时间调用。 + on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 + 特定的时间调用。 Example:: + from fastNLP import Event @Trainer.on(Event.on_save_model()) def do_something_1(trainer): @@ -696,7 +701,7 @@ class Trainer(TrainerEventTrigger): r""" 用于断点重训的加载函数; 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 - 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; + 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; 注意我们目前不支持单卡到多卡的断点重训; diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py index 6cccde1e..528ab529 100644 --- a/fastNLP/core/controllers/utils/state.py +++ b/fastNLP/core/controllers/utils/state.py @@ -26,7 +26,8 @@ class State(dict): 为了实现断点重训,用户应当保证其保存的信息都是可序列化的; - 推荐的使用方式: + 推荐的使用方式:: + >>> state = State() >>> state["best_accuracy"] = 0.9 >>> print(state["best_accuracy"]) diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 48feea0b..f23e80e9 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -142,6 +142,7 @@ class JittorDataLoader: """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Example:: + collator.set_ignore('field1', 'field2') :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 140c03bc..977197f6 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -144,6 +144,7 @@ class PaddleDataLoader(DataLoader): """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Example:: + collator.set_ignore('field1', 'field2') :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 643b7ad3..48fee045 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -153,6 +153,7 @@ class TorchDataLoader(DataLoader): """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Example:: + collator.set_ignore('field1', 'field2') :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 44c1d444..83e83ac9 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -706,8 +706,8 @@ class DataSet: def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': """ 将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target - 以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 - 当前dataset含有field,则会报错。 + 以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 + 当前dataset含有field,则会报错。 :param DataSet, dataset: 需要和当前dataset concat的dataset :param bool, inplace: 是否直接将dataset组合到当前dataset中 diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py new file mode 100644 index 00000000..5696b4c7 --- /dev/null +++ b/fastNLP/core/drivers/choose_driver.py @@ -0,0 +1,31 @@ +from typing import Union, Optional, List + +from .driver import Driver + + +def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: + r""" + 根据输入的参数 'gpus' 的格式来决定具体的工作模式; + + :param model: 运行过程中使用的具体的最原始的模型; + :param driver: 应当为字符串或者 `Driver` 实例,表示运行中具体使用的训练/评测模式; + :param device: 具体的形式请参见 `fastNLP.core.drivers.torch_driver.utils.initialize_torch_dirver` 的注释; + :param kwargs: 其余的传给 `Driver` 的参数; + """ + + # 如果用户直接传进来一个 driver 实例,我们就直接返回回去,目前用户需要自己保证传进来的 driver 的正确性; + if isinstance(driver, Driver): + return driver + + if driver in {"torch", "torch_ddp", "fairscale"}: + from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver + return initialize_torch_driver(driver, device, model, **kwargs) + elif driver in {"jittor"}: + from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver + return initialize_jittor_driver(driver, device, model, **kwargs) + elif driver in {"paddle", "fleet"}: + from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver + return initialize_paddle_driver(driver, device, model, **kwargs) + else: + raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', " + "'jittor', 'paddle', 'fleet'].") \ No newline at end of file diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 01bae72a..6ce168cb 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -87,8 +87,8 @@ class Driver(ABC): :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; :param fn: 调用该函数进行一次计算。 - :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call - 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; + :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函 + 数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); """ raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") @@ -106,9 +106,10 @@ class Driver(ABC): `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: - 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` - 函数,然后给出 warning; - 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; + 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` + 函数,然后给出 warning; + 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; + 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 可能需要额外标记最初传入 driver 的模型是哪种形式的; @@ -376,7 +377,7 @@ class Driver(ABC): 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; 因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 - pid 的信息; + pid 的信息; """ # 单卡 driver 不需要这个函数; if self._pids is not None: diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index bcebc6d0..b751354d 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -33,11 +33,12 @@ class JittorDriver(Driver): f"`jittor.Module` type.") super(JittorDriver, self).__init__(model) - self.model = model - self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) self.grad_scaler = _grad_scaler() + # 用来设置是否关闭 auto_param_call 中的参数匹配问题; + self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) + @staticmethod def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): # 在fastnlp中实现了JittorDataLoader @@ -152,4 +153,4 @@ class JittorDriver(Driver): # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): # # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; # if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): - # dataloader.batch_sampler.set_epoch(cur_epoch_idx) \ No newline at end of file + # dataloader.batch_sampler.set_epoch(cur_epoch_idx) diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 695e6ec9..ab1e8595 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver): logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') return fn, None elif fn in {"train_step", "evaluate_step"}: - logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') - return self.model, self.model.forward + logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...') + return self.model, self.model.execute else: raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") @@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver): return dataloader else: return dataloader + + def setup(self): + """ + 使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作 + """ + pass diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py index 82cabb02..ffa142d3 100644 --- a/fastNLP/core/drivers/paddle_driver/dist_utils.py +++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py @@ -172,6 +172,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List: 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 example:: + obj = { 'a': [1, 1], 'b': [[1, 2], [1, 2]], diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 0fca3856..022c397f 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -534,7 +534,7 @@ class TorchDDPDriver(TorchDriver): def broadcast_object(self, obj, src:int=0, group=None, **kwargs): """ 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 - 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 + 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 :param obj: obj,可能是 Tensor 或 嵌套类型的数据 :param int src: source 的 global rank 。 @@ -551,9 +551,10 @@ class TorchDDPDriver(TorchDriver): def all_gather(self, obj, group) -> List: """ 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 - pickle 进行序列化,接收到之后再反序列化。 + pickle 进行序列化,接收到之后再反序列化。 + + example:: - example: obj = { 'a': [1, 1], 'b': [[1, 2], [1, 2]], diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index c77b8416..c5b90655 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -175,7 +175,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) - """ 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 - example: + example:: + obj = { 'a': [1, 1], 'b': [[1, 2], [1, 2]], diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index d756cf77..e4c84bf8 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -175,16 +175,18 @@ def _build_fp16_env(dummy=False): def replace_sampler(dataloader: "DataLoader", sampler): """ - 替换 sampler (初始化一个新的 dataloader 的逻辑在于): + 替换 sampler (初始化一个新的 dataloader 的逻辑在于): - 用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 - `inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader - 的类,而不是直接的 DataLoader; + 用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 + `inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader + 的类,而不是直接的 DataLoader; + + 如果需要定制自己的 dataloader,保证以下两点: + + 1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; + 2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 + 来获取实际的参数的值; - 如果需要定制自己的 dataloader,保证以下两点: - 1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; - 2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 - 来获取实际的参数的值; """ # 拿到实例属性; diff --git a/fastNLP/core/drivers/utils.py b/fastNLP/core/drivers/utils.py index 040747f0..09cac2b9 100644 --- a/fastNLP/core/drivers/utils.py +++ b/fastNLP/core/drivers/utils.py @@ -1,38 +1,5 @@ -from typing import Optional -from typing import Union, List +from typing import List import subprocess -from pathlib import Path - -from fastNLP.core.drivers.driver import Driver - - - -def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: - r""" - 根据输入的参数 'gpus' 的格式来决定具体的工作模式; - - :param model: 运行过程中使用的具体的最原始的模型; - :param driver: 应当为字符串或者 `Driver` 实例,表示运行中具体使用的训练/评测模式; - :param device: 具体的形式请参见 `fastNLP.core.drivers.torch_driver.utils.initialize_torch_dirver` 的注释; - :param kwargs: 其余的传给 `Driver` 的参数; - """ - - # 如果用户直接传进来一个 driver 实例,我们就直接返回回去,目前用户需要自己保证传进来的 driver 的正确性; - if isinstance(driver, Driver): - return driver - - if driver in {"torch", "torch_ddp", "fairscale"}: - from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver - return initialize_torch_driver(driver, device, model, **kwargs) - elif driver in {"jittor"}: - from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver - return initialize_jittor_driver(driver, device, model, **kwargs) - elif driver in {"paddle", "fleet"}: - from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver - return initialize_paddle_driver(driver, device, model, **kwargs) - else: - raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', " - "'jittor', 'paddle', 'fleet'].") def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 809e9c5c..86d52041 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -1,18 +1,20 @@ r""" Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, 具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API -使用方式: -from fastNLP import _logger -# -# _logger 可以和 logging.Logger 一样使用 -_logger.info('your msg') -_logger.error('your msg') - -# _logger 新增的API -# 将日志输出到文件,以及输出的日志等级 -_logger.add_file('/path/to/log', level='INFO') -# 定义在命令行中的显示格式和日志等级 -_logger.set_stdout('tqdm', level='WARN') + +使用方式:: + + from fastNLP import _logger + # + # _logger 可以和 logging.Logger 一样使用 + _logger.info('your msg') + _logger.error('your msg') + + # _logger 新增的API + # 将日志输出到文件,以及输出的日志等级 + _logger.add_file('/path/to/log', level='INFO') + # 定义在命令行中的显示格式和日志等级 + _logger.set_stdout('tqdm', level='WARN') """ diff --git a/fastNLP/core/log/print.py b/fastNLP/core/log/print.py index 610ae0bd..b3d328ed 100644 --- a/fastNLP/core/log/print.py +++ b/fastNLP/core/log/print.py @@ -10,12 +10,13 @@ def print(*args, sep=' ', end='\n', file=None, flush=False): 用来重定向 print 函数至 logger.info 的函数。 Example:: + from fastNLP import print print("This is a test") # 等价于调用了 logger.info("This is a test") :param args: 需要打印的内容 :param sep: 存在多个输入时,使用的间隔。 - :param end: 该参数在当前设置无意义,因为结尾一定会被加入 \n 。 + :param end: 该参数在当前设置无意义,因为结尾一定会被加入 '\\\\n' 。 :param file: 该参数无意义。 :param flush: 该参数无意义。 :return: diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index b5fc44dd..6a32ef60 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -38,7 +38,7 @@ class Metric: def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: """ 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 - tensor 直接进行加减乘除计算即可。 + tensor 直接进行加减乘除计算即可。 注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 @@ -48,7 +48,7 @@ class Metric: Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含 - jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 + jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 到任何一种 tensor ,就默认使用 float 类型作为 element 。 :return: 注册的 Element 对象 """ diff --git a/fastNLP/core/samplers/mix_sampler.py b/fastNLP/core/samplers/mix_sampler.py index f53c06a5..0aa543be 100644 --- a/fastNLP/core/samplers/mix_sampler.py +++ b/fastNLP/core/samplers/mix_sampler.py @@ -496,7 +496,7 @@ class PollingSampler(MixSampler): :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size :param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时, - 以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 + 以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 """ super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size, sampler=sampler, ds_ratio=ds_ratio, diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py index 2badc0dd..f8535ed2 100644 --- a/fastNLP/core/samplers/utils.py +++ b/fastNLP/core/samplers/utils.py @@ -35,7 +35,9 @@ class NumConsumedSamplesArray: def __init__(self, buffer_size=2000, num_consumed_samples=0): """ 保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 + Example:: + array = NumConsumedSamplesArray(buffer_size=3) for i in range(10): array.push(i) diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index ea716fe8..4de52d16 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -24,6 +24,7 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', + "flat_nest_dict" ] from .cache_results import cache_results @@ -33,8 +34,6 @@ from .paddle_utils import get_device_from_visible, paddle_to, paddle_move_data_t from .rich_progress import f_rich_progress from .torch_paddle_utils import torch_paddle_move_data_to_device from .torch_utils import torch_move_data_to_device -from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ - dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ - deprecated, seq_len_to_mask +from .utils import * diff --git a/fastNLP/core/utils/cache_results.py b/fastNLP/core/utils/cache_results.py index cde4a51e..3313b9a1 100644 --- a/fastNLP/core/utils/cache_results.py +++ b/fastNLP/core/utils/cache_results.py @@ -222,7 +222,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec 可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。 如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose, - _check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称:: + _check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称。 :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到 diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index a96b5bd1..edb41032 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -35,6 +35,7 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', + "flat_nest_dict" ] @@ -256,12 +257,13 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, 对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; 转换的逻辑按优先级依次为: - 1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; - 2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; - 如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; - 如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; - 如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 - mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 + + 1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; + 2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; + 如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; + 如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; + 如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 + mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 :param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 :param data: 需要被转换的对象; @@ -439,12 +441,16 @@ def _is_iterable(value): def pretty_table_printer(dataset_or_ins) -> PrettyTable: r""" :param dataset_or_ins: 传入一个dataSet或者instance - ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) - +-----------+-----------+-----------------+ - | field_1 | field_2 | field_3 | - +-----------+-----------+-----------------+ - | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | - +-----------+-----------+-----------------+ + + .. code-block:: + + ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) + +-----------+-----------+-----------------+ + | field_1 | field_2 | field_3 | + +-----------+-----------+-----------------+ + | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | + +-----------+-----------+-----------------+ + :return: 以 pretty table的形式返回根据terminal大小进行自动截断 """ x = PrettyTable() @@ -640,4 +646,55 @@ def is_notebook(): except: return False else: # pragma: no cover - return True \ No newline at end of file + return True + + +def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: + """ + 讲一个 nested 的 dict 转成 flat 的 dict,例如 + ex:: + d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} + + :param d: 需要展平的 dict 对象。 + :param separator: 不同层级之间的 key 之间的连接符号。 + :param compress_none_key: 如果有 key 为 None ,则忽略这一层连接。 + :param top_down: 新的 key 的是否按照从最底层往最底层的顺序连接。 + :return: + """ + assert isinstance(d, Dict) + assert isinstance(separator, str) + flat_d = {} + for key, value in d.items(): + if key is None: + key = () + else: + key = (key, ) + if isinstance(value, Mapping): + flat_d.update(_flat_nest_dict(value, parent_key=key, compress_none_key=compress_none_key)) + else: + flat_d[key] = value + + str_flat_d = {} + for key, value in flat_d.items(): + if top_down: + key = map(str, key) + else: + key = map(str, key[::-1]) + key = separator.join(key) + str_flat_d[key] = value + return str_flat_d + + +def _flat_nest_dict(d:Mapping, parent_key:Tuple, compress_none_key:bool): + flat_d = {} + for k, v in d.items(): + _key = parent_key + if k is not None: + _key = _key + (k,) + if isinstance(v, Mapping): + _d = _flat_nest_dict(v, parent_key=_key, compress_none_key=compress_none_key) + flat_d.update(_d) + else: + flat_d[_key] = v + + return flat_d diff --git a/fastNLP/envs/distributed.py b/fastNLP/envs/distributed.py index 0b8f4f74..adcfb085 100644 --- a/fastNLP/envs/distributed.py +++ b/fastNLP/envs/distributed.py @@ -47,7 +47,7 @@ def rank_zero_call(fn: Callable): rank_zero_call(add)(1, 2) 同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何 - 意义。 + 意义。 :param fn: 需要包裹的可执行的函数。 :return: @@ -65,7 +65,7 @@ def rank_zero_call(fn: Callable): def fastnlp_no_sync_context(level=2): """ 用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; - 如果为 2 表示 barrier 与 gather/broadcast 都失效。 + 如果为 2 表示 barrier 与 gather/broadcast 都失效。 :param int level: 可选 [0, 1, 2] :return: @@ -84,9 +84,10 @@ def all_rank_call_context(): """ 在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 - # 使用方式 - with all_rank_call_context(): - do_something # all rank will do + 使用方式:: + + with all_rank_call_context(): + do_something # all rank will do :param fn: :return: diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 4a23990d..a3c15a28 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -233,8 +233,8 @@ class DataBundle: 如果为False,则报错 :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 - :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 - :param show_progress_bar 是否显示tqdm进度条 + :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 + :param show_progress_bar: 是否显示tqdm进度条 """ _progress_desc = progress_desc diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py new file mode 100644 index 00000000..d0eac8cd --- /dev/null +++ b/tests/core/controllers/test_trainer_jittor.py @@ -0,0 +1,133 @@ +import pytest + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.controllers.trainer import Evaluator +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR + +if _NEED_IMPORT_JITTOR: + import jittor as jt + from jittor import nn, Module + from jittor.dataset import Dataset + + +class JittorNormalModel_Classification(Module): + """ + 基础的 Jittor 分类模型 + """ + + def __init__(self, num_labels, feature_dimension): + super(JittorNormalModel_Classification, self).__init__() + self.num_labels = num_labels + + self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) + self.ac1 = nn.ReLU() + self.linear2 = nn.Linear(in_features=64, out_features=32) + self.ac2 = nn.ReLU() + self.output = nn.Linear(in_features=32, out_features=num_labels) + self.loss_fn = nn.CrossEntropyLoss() + + def execute(self, x): + # It's similar to forward function in Pytorch + x = self.ac1(self.linear1(x)) + x = self.ac2(self.linear2(x)) + x = self.output(x) + return x + + def train_step(self, x, y): + x = self(x) + return {"loss": self.loss_fn(x, y)} + + def evaluate_step(self, x, y): + x = self(x) + return {"pred": x, "target": y.reshape((-1,))} + + +class JittorRandomMaxDataset(Dataset): + def __init__(self, num_samples, num_features): + super(JittorRandomMaxDataset, self).__init__() + self.x = jt.randn((num_samples, num_features)) + self.y = self.x.argmax(dim=1)[0] + + def __len__(self): + return len(self.y) + + def __getitem__(self, item): + return {"x": self.x[item], "y": self.y[item]} + + +class TrainJittorConfig: + num_labels: int = 5 + feature_dimension: int = 5 + lr = 1e-1 + batch_size: int = 4 + shuffle: bool = True + + +@pytest.mark.parametrize("driver,device", [("jittor", None)]) +@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) +def test_trainer_jittor( + driver, + device, + callbacks, + n_epochs=3, +): + model = JittorNormalModel_Classification( + num_labels=TrainJittorConfig.num_labels, + feature_dimension=TrainJittorConfig.feature_dimension + ) + optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr) + train_dataloader = JittorDataLoader( + dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension), + batch_size=TrainJittorConfig.batch_size, + shuffle=True, + # num_workers=4, + ) + val_dataloader = JittorDataLoader( + dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension), + batch_size=TrainJittorConfig.batch_size, + shuffle=True, + # num_workers=4, + ) + test_dataloader = JittorDataLoader( + dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension), + batch_size=TrainJittorConfig.batch_size, + shuffle=True, + # num_workers=4, + ) + metrics = {"acc": Accuracy()} + + trainer = Trainer( + model=model, + driver=driver, + device=device, + optimizers=optimizer, + train_dataloader=train_dataloader, + evaluate_dataloaders=val_dataloader, + validate_every=-1, + evaluate_fn="evaluate_step", + input_mapping=None, + output_mapping=None, + metrics=metrics, + n_epochs=n_epochs, + callbacks=callbacks, + # progress_bar="rich" + ) + trainer.run() + + evaluator = Evaluator( + model=model, + driver=driver, + dataloaders=test_dataloader, + evaluate_fn="evaluate_step", + metrics=metrics, + ) + metric_results = evaluator.run() + assert metric_results["acc#acc"] > 0.80 + + +if __name__ == "__main__": + # test_trainer_jittor("jittor", None, [RichCallback(100)]) + pytest.main(['test_trainer_jittor.py']) # 只运行此模块 diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 8971b2fe..1eb1ea4d 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -174,7 +174,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( dist.destroy_process_group() @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_validate_every( model_and_optimizers: TrainerParameters, @@ -234,7 +234,7 @@ def test_trainer_on( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + evaluate_dataloaders={"dl":model_and_optimizers.evaluate_dataloaders}, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics,