diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index d07382e4..02b56cd7 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -14,7 +14,7 @@ __all__ = [
     'MoreEvaluateCallback',
     "TorchWarmupCallback",
     "TorchGradClipCallback",
-    "MonitorUtility",
+    "ResultsMonitor",
     'HasMonitorCallback',

     # collators
diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py
index 6f859183..9ba0d227 100644
--- a/fastNLP/core/callbacks/__init__.py
+++ b/fastNLP/core/callbacks/__init__.py
@@ -16,7 +16,7 @@ __all__ = [

     "TorchWarmupCallback",
     "TorchGradClipCallback",

-    "MonitorUtility",
+    "ResultsMonitor",
     'HasMonitorCallback'
 ]
@@ -31,5 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback
 from .early_stop_callback import EarlyStopCallback
 from .torch_callbacks import *
 from .more_evaluate_callback import MoreEvaluateCallback
-from .has_monitor_callback import MonitorUtility, HasMonitorCallback
+from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py
index 8e1b64de..eabc489b 100644
--- a/fastNLP/core/callbacks/callback_manager.py
+++ b/fastNLP/core/callbacks/callback_manager.py
@@ -57,7 +57,7 @@ def prepare_callbacks(callbacks, progress_bar):
     if has_no_progress and progress_bar is not None:
         callback = choose_progress_callback(progress_bar)
         if callback is not None:
-            _callbacks.append(callback)
+            _callbacks = [callback] + _callbacks  # put the progress callback first so it can delimit the epochs
             has_no_progress = False
     elif has_no_progress is False and progress_bar not in ('auto', None):
         logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")
diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py
index 8e5eb0aa..2d1affd2 100644
--- a/fastNLP/core/callbacks/has_monitor_callback.py
+++ b/fastNLP/core/callbacks/has_monitor_callback.py
@@ -1,7 +1,7 @@
 __all__ = [
     'HasMonitorCallback',
     'ExecuteOnceBetterMonitor',
-    'MonitorUtility'
+    'ResultsMonitor'
 ]

 from typing import Dict, Union, Any
@@ -29,12 +29,16 @@ class CanItemDataType(ABC):
         return NotImplemented


-class MonitorUtility:
-    """
-    Helper functions for computing the monitor value.
+class ResultsMonitor:
+    def __init__(self, monitor:Union[Callback, str], larger_better:bool=True):
+        """
+        Monitors a value and, through interfaces such as is_better_results(), detects whether the results have improved.

-    """
-    def __init__(self, monitor, larger_better):
+        :param monitor: the metric value to monitor. If no exactly matching name is found in the evaluation results, the
+            shortest-common-string algorithm is used to pick the best match as the monitor. If None, the monitor set by
+            the Trainer is used. A function may also be passed in; it receives the evaluation results (a dict) and
+            returns a float as the monitor value, or None when the current results contain no relevant monitor value.
+        :param larger_better: whether a larger monitor value is better
+        """
         self.set_monitor(monitor, larger_better)

     def set_monitor(self, monitor, larger_better):
@@ -53,7 +57,7 @@ class MonitorUtility:

     def itemize_results(self, results):
         """
-        Call .item() on everything in the results that has it, so that the results can be saved
+        Call .item() on everything in the results that has it, converting tensor-typed values into Python built-in types.

         :param results:
         :return:
@@ -161,7 +165,7 @@ class MonitorUtility:
         return monitor_name


-class HasMonitorCallback(MonitorUtility, Callback):
+class HasMonitorCallback(ResultsMonitor, Callback):
     def __init__(self, monitor, larger_better, must_have_monitor=False):
         """
         This callback is not used directly; it serves as the base class for other monitor-related callbacks. A callback that uses a monitor can inherit from this class, which implements
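For reference, a minimal sketch of how the renamed `ResultsMonitor` is meant to be driven. The constructor signature comes from the hunk above; the result key `'acc#acc#dl'` and the exact behavior of `is_better_results` on the first call are assumptions based on the docstring, not part of the diff:

```python
from fastNLP.core.callbacks import ResultsMonitor

# A string monitor is matched against result keys, falling back to fuzzy matching
# when no exact key exists ('acc#acc#dl' is an illustrative key).
monitor = ResultsMonitor(monitor='acc', larger_better=True)
print(monitor.is_better_results({'acc#acc#dl': 0.80}))  # first results seen
print(monitor.is_better_results({'acc#acc#dl': 0.85}))  # improved, so expected True

# A callable monitor receives the raw evaluation dict and returns a float,
# or None when the dict carries no relevant value.
fn_monitor = ResultsMonitor(monitor=lambda results: results.get('acc#acc#dl'),
                            larger_better=True)
```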
diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py
index cf6881d7..09843511 100644
--- a/fastNLP/core/callbacks/topk_saver.py
+++ b/fastNLP/core/callbacks/topk_saver.py
@@ -12,7 +12,7 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME
-from .has_monitor_callback import MonitorUtility
+from .has_monitor_callback import ResultsMonitor


 class Saver:
@@ -170,7 +170,7 @@
         return self.topk != 0


-class TopkSaver(MonitorUtility, Saver):
+class TopkSaver(ResultsMonitor, Saver):
     def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None,
                  save_object:str='model', only_state_dict:bool=True, model_save_fn:Callable=None,
                  save_evaluate_results:bool=True, **kwargs):
@@ -196,7 +196,7 @@
             a fastnlp_evaluate_results.json file recording the current results. Only useful when topk is set; defaults to True.
         :param kwargs: additional parameters passed through to the Trainer.save() or Trainer.save_model() interface.
         """
-        MonitorUtility.__init__(self, monitor, larger_better)
+        ResultsMonitor.__init__(self, monitor, larger_better)
         Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)

         if monitor is not None and topk == 0:
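A hedged construction sketch for the rewired base class; all parameter names come from the signature above, and the folder path is a placeholder:

```python
from fastNLP.core.callbacks.topk_saver import TopkSaver

# Keep the two best checkpoints ranked by the 'acc' monitor; ResultsMonitor.__init__
# gives TopkSaver the same monitor-matching behavior as the monitor callbacks above.
saver = TopkSaver(topk=2, monitor='acc', larger_better=True,
                  folder='outputs/ckpts',  # placeholder path
                  save_object='model', only_state_dict=True,
                  save_evaluate_results=True)
```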
diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index 4583bae2..48aee094 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -8,10 +8,10 @@ __all__ = [
 ]

 from fastNLP.core.drivers import Driver
-from fastNLP.core.drivers.utils import choose_driver
+from ..drivers.choose_driver import choose_driver
 from .loops import Loop, EvaluateBatchLoop
 from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
-    match_and_substitute_params, f_rich_progress
+    match_and_substitute_params, f_rich_progress, flat_nest_dict
 from fastNLP.core.metrics import Metric
 from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
@@ -162,13 +162,15 @@ class Evaluator:
                     self.cur_dataloader_name = dataloader_name
                     results = self.evaluate_batch_loop.run(self, dataloader)
                     self.remove_progress_bar(dataloader_name)
-                    metric_results.update(results)
+                    metric_results[dataloader_name] = results
                     self.reset()
                     self.driver.barrier()
         except BaseException as e:
             raise e
         finally:
             self.finally_progress_bar()
+
+        metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)

         self.driver.set_model_mode(mode='train')
         if self.verbose:
             if self.progress_bar == 'rich':
@@ -251,14 +253,13 @@ class Evaluator:
         """
         self.metrics_wrapper.update(batch, outputs)

-    def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
+    def get_metric(self) -> Dict:
         """
-        Get the metric results for the current dataloader
+        Call the get_metric method of every metric and return the results, where each key is a metric's name and each value is that metric's result.

-        :param str dataloader_name: name of the current dataloader
         :return:
         """
-        return self.metrics_wrapper.get_metric(dataloader_name=dataloader_name, separator=self.separator)
+        return self.metrics_wrapper.get_metric()

     @property
     def metrics_wrapper(self):
@@ -366,15 +367,12 @@ class _MetricsWrapper:
         elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
             metric.reset()

-    def get_metric(self, dataloader_name: str, separator: str) -> Dict:
+    def get_metric(self) -> Dict:
         """
-        Flatten all metric results into a single-level dict whose keys follow the pattern
-            indicator_name{separator}metric_name{separator}dataloader_name
-        e.g. f1#F1PreRec#dev
+        Call each metric to obtain its result, returned in the form {'metric_name1': metric_results,
+        'metric_name2': metric_results}.

-        :param dataloader_name: name of the dataloader this metric corresponds to. If empty, it is omitted from the final key.
-        :param separator: separator between the different name parts.
-        :return: a single-level dict whose keys identify each metric and whose values are the metric values;
+        :return:
         """
         results = {}
         for metric_name, metric in zip(self._metric_names, self._metrics):
@@ -384,37 +382,9 @@ class _MetricsWrapper:
             if isinstance(metric, Metric):
                 _results = metric.get_metric(reset=False)
             elif _is_torchmetrics_metric(metric):
                 _results = metric.compute()
-            # We require metrics in the evaluator to output a dict; so when the metric is a torchmetrics metric and the
-            # evaluator was given no func_post_proc, we automatically use the metric's name as its indicator name and
-            # convert the result into a dict;
             elif _is_paddle_metric(metric):
                 _results = metric.accumulate()
-            if not isinstance(_results, Dict):
-                name = _get_metric_res_name(dataloader_name, metric_name, '', separator)
-                results[name] = _results
             else:
-                for indicator_name, value in _results.items():
-                    name = _get_metric_res_name(dataloader_name, metric_name, indicator_name, separator)
-                    results[name] = value
-
+                raise RuntimeError(f"Not support `{type(metric)}` for now.")
+            results[metric_name] = _results
         return results
-
-
-def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indicator_name: str, separator='#') -> str:
-    """
-
-    :param dataloader_name: name of the dataloader
-    :param metric_name: name of the metric
-    :param indicator_name: name of each indicator inside the metric, e.g. f, precision, recall
-    :param separator: separator between the different parts
-    :return:
-    """
-    names = []
-    if indicator_name:
-        names.append(indicator_name)
-    if metric_name:
-        names.append(metric_name)
-    if dataloader_name:
-        names.append(dataloader_name)
-    if len(names) == 0:
-        raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.")
-    return separator.join(names)
diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py
index 2d8f07d1..0bf66fda 100644
--- a/fastNLP/core/controllers/loops/evaluate_batch_loop.py
+++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py
@@ -40,8 +40,8 @@ class EvaluateBatchLoop(Loop):
                 self.batch_step_fn(evaluator, batch)
                 batch_idx += 1
                 evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
-        # Fetch the metric results. The returned dict looks like {'f1#F1Metric#dl1': 0.93, 'pre#F1Metric#dl1': 0.95, ...}
-        results = evaluator.get_dataloader_metric(dataloader_name=evaluator.cur_dataloader_name)
+        # Fetch the metric results. The returned dict looks like {'metric_name1': metric_results, 'metric_name2': metric_results, ...}
+        results = evaluator.get_metric()
         return results

     @staticmethod
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 49a54a07..6c1117db 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -23,7 +23,7 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
 from fastNLP.core.callbacks.callback_manager import prepare_callbacks
 from fastNLP.core.callbacks.callback_event import Event
 from fastNLP.core.drivers import Driver
-from fastNLP.core.drivers.utils import choose_driver
+from ..drivers.choose_driver import choose_driver
 from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
 from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.envs import rank_zero_call
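The net effect of the evaluator changes, as a sketch (metric and dataloader names are illustrative): each dataloader now contributes one nested entry, and `flat_nest_dict` restores the old flat key layout at the end of `run()`:

```python
from fastNLP.core.utils import flat_nest_dict

# what Evaluator.run now accumulates: {dataloader_name: {metric_name: {indicator: value}}}
nested = {'dl1': {'F1Metric': {'f1': 0.93, 'pre': 0.95}}}

# with top_down=False the key parts are joined bottom-up, reproducing the
# 'indicator#metric#dataloader' naming that _get_metric_res_name used to build
flat = flat_nest_dict(nested, separator='#', compress_none_key=True, top_down=False)
assert flat == {'f1#F1Metric#dl1': 0.93, 'pre#F1Metric#dl1': 0.95}
```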
diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py
new file mode 100644
index 00000000..5696b4c7
--- /dev/null
+++ b/fastNLP/core/drivers/choose_driver.py
@@ -0,0 +1,31 @@
+from typing import Union, Optional, List
+
+from .driver import Driver
+
+
+def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
+    r"""
+    Decide the concrete working mode based on the format of the input parameter 'gpus';
+
+    :param model: the original model used during the run;
+    :param driver: a string or a `Driver` instance, indicating the training/evaluation mode to use;
+    :param device: see the docstring of `fastNLP.core.drivers.torch_driver.utils.initialize_torch_driver` for the accepted formats;
+    :param kwargs: the remaining parameters passed on to the `Driver`;
+    """
+
+    # If the user passes in a driver instance directly, return it as-is; for now the user must guarantee its correctness;
+    if isinstance(driver, Driver):
+        return driver
+
+    if driver in {"torch", "torch_ddp", "fairscale"}:
+        from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
+        return initialize_torch_driver(driver, device, model, **kwargs)
+    elif driver in {"jittor"}:
+        from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
+        return initialize_jittor_driver(driver, device, model, **kwargs)
+    elif driver in {"paddle", "fleet"}:
+        from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
+        return initialize_paddle_driver(driver, device, model, **kwargs)
+    else:
+        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', "
+                         "'jittor', 'paddle', 'fleet'].")
\ No newline at end of file
diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
index bcebc6d0..b751354d 100644
--- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -33,11 +33,12 @@ class JittorDriver(Driver):
                             f"`jittor.Module` type.")
         super(JittorDriver, self).__init__(model)

-        self.model = model
-
         self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
         self.grad_scaler = _grad_scaler()

+        # controls whether argument matching in auto_param_call is disabled;
+        self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
     @staticmethod
     def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         # JittorDataLoader is implemented in fastNLP
@@ -152,4 +153,4 @@ class JittorDriver(Driver):
     # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
    #     # ensures correctness of shuffle=True under ddp training, since the shuffle random seed of the sampler must be identical on every process;
     #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
-    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
\ No newline at end of file
+    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 695e6ec9..ab1e8595 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
             logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
             return fn, None
         elif fn in {"train_step", "evaluate_step"}:
-            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
-            return self.model, self.model.forward
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
         else:
             raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
@@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
             return dataloader
         else:
             return dataloader
+
+    def setup(self):
+        """
+        On a single GPU, jittor handles the allocation automatically under the hood; no extra work is needed.
+        """
+        pass
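Since the dispatch logic of the new module is fully shown above, a usage sketch is straightforward (assumes torch and a patched fastNLP are installed):

```python
import torch
from fastNLP.core.drivers.choose_driver import choose_driver

model = torch.nn.Linear(4, 2)
# "torch" dispatches to initialize_torch_driver; device=0 selects GPU 0
driver = choose_driver(model, "torch", device=0)
# a ready-made Driver instance is returned unchanged
assert choose_driver(model, driver, device=None) is driver
```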
diff --git a/fastNLP/core/drivers/utils.py b/fastNLP/core/drivers/utils.py
index d2a221d4..09cac2b9 100644
--- a/fastNLP/core/drivers/utils.py
+++ b/fastNLP/core/drivers/utils.py
@@ -1,38 +1,5 @@
-from typing import Optional
-from typing import Union, List
+from typing import List
 import subprocess
-from pathlib import Path
-
-from fastNLP.core.drivers.driver import Driver
-
-__all__ = []
-
-def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
-    r"""
-    Decide the concrete working mode based on the format of the input parameter 'gpus';
-
-    :param model: the original model used during the run;
-    :param driver: a string or a `Driver` instance, indicating the training/evaluation mode to use;
-    :param device: see the docstring of `fastNLP.core.drivers.torch_driver.utils.initialize_torch_driver` for the accepted formats;
-    :param kwargs: the remaining parameters passed on to the `Driver`;
-    """
-
-    # If the user passes in a driver instance directly, return it as-is; for now the user must guarantee its correctness;
-    if isinstance(driver, Driver):
-        return driver
-
-    if driver in {"torch", "torch_ddp", "fairscale"}:
-        from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
-        return initialize_torch_driver(driver, device, model, **kwargs)
-    elif driver in {"jittor"}:
-        from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
-        return initialize_jittor_driver(driver, device, model, **kwargs)
-    elif driver in {"paddle", "fleet"}:
-        from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
-        return initialize_paddle_driver(driver, device, model, **kwargs)
-    else:
-        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', "
-                         "'jittor', 'paddle', 'fleet'].")


 def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None):
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
index ea716fe8..4de52d16 100644
--- a/fastNLP/core/utils/__init__.py
+++ b/fastNLP/core/utils/__init__.py
@@ -24,6 +24,7 @@ __all__ = [
     'Option',
     'deprecated',
     'seq_len_to_mask',
+    "flat_nest_dict"
 ]

 from .cache_results import cache_results
@@ -33,8 +34,6 @@ from .paddle_utils import get_device_from_visible, paddle_to, paddle_move_data_to_device
 from .rich_progress import f_rich_progress
 from .torch_paddle_utils import torch_paddle_move_data_to_device
 from .torch_utils import torch_move_data_to_device
-from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
-    dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
-    deprecated, seq_len_to_mask
+from .utils import *
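With the wildcard re-export above, `__all__` in utils.py becomes the single source of truth for the package surface; a quick check (assumes fastNLP with this patch is installed):

```python
# both paths resolve after this change, because `from .utils import *`
# re-exports exactly the names listed in fastNLP.core.utils.utils.__all__
from fastNLP.core.utils import flat_nest_dict
from fastNLP.core.utils.utils import flat_nest_dict as _same
assert flat_nest_dict is _same
```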
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
index 1c54c03c..edb41032 100644
--- a/fastNLP/core/utils/utils.py
+++ b/fastNLP/core/utils/utils.py
@@ -35,6 +35,7 @@ __all__ = [
     'Option',
     'deprecated',
     'seq_len_to_mask',
+    "flat_nest_dict"
 ]

@@ -645,4 +646,55 @@ def is_notebook():
     except:
         return False
     else:  # pragma: no cover
-        return True
\ No newline at end of file
+        return True
+
+
+def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict:
+    """
+    Flatten a nested dict into a flat dict, e.g.
+    ex::
+        d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test': 0.2, 'rec#f1#test': 0.1}
+
+    :param d: the dict to flatten.
+    :param separator: the symbol used to join keys from different levels.
+    :param compress_none_key: if a key is None, skip that level when joining.
+    :param top_down: whether the new keys are joined from the top level down to the bottom one.
+    :return:
+    """
+    assert isinstance(d, Dict)
+    assert isinstance(separator, str)
+    flat_d = {}
+    for key, value in d.items():
+        if key is None:
+            key = ()
+        else:
+            key = (key, )
+        if isinstance(value, Mapping):
+            flat_d.update(_flat_nest_dict(value, parent_key=key, compress_none_key=compress_none_key))
+        else:
+            flat_d[key] = value
+
+    str_flat_d = {}
+    for key, value in flat_d.items():
+        if top_down:
+            key = map(str, key)
+        else:
+            key = map(str, key[::-1])
+        key = separator.join(key)
+        str_flat_d[key] = value
+    return str_flat_d
+
+
+def _flat_nest_dict(d:Mapping, parent_key:Tuple, compress_none_key:bool):
+    flat_d = {}
+    for k, v in d.items():
+        _key = parent_key
+        if k is not None:
+            _key = _key + (k,)
+        if isinstance(v, Mapping):
+            _d = _flat_nest_dict(v, parent_key=_key, compress_none_key=compress_none_key)
+            flat_d.update(_d)
+        else:
+            flat_d[_key] = v
+
+    return flat_d
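Sanity-checking the new helper against its own docstring example, plus the None-key behavior the evaluator relies on:

```python
from fastNLP.core.utils import flat_nest_dict

d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}}
assert flat_nest_dict(d) == {'f#f1#test': 0.2, 'rec#f1#test': 0.1}  # bottom-up join by default

# a None key contributes nothing to the joined name, so a single unnamed
# dataloader leaves no trailing '#...' part in the metric keys
assert flat_nest_dict({None: {'acc': {'acc': 0.9}}}) == {'acc#acc': 0.9}
```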
diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py
new file mode 100644
index 00000000..d0eac8cd
--- /dev/null
+++ b/tests/core/controllers/test_trainer_jittor.py
@@ -0,0 +1,133 @@
+import pytest
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.controllers.trainer import Evaluator
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+    import jittor as jt
+    from jittor import nn, Module
+    from jittor.dataset import Dataset
+
+
+class JittorNormalModel_Classification(Module):
+    """
+    A basic Jittor classification model
+    """
+
+    def __init__(self, num_labels, feature_dimension):
+        super(JittorNormalModel_Classification, self).__init__()
+        self.num_labels = num_labels
+
+        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
+        self.ac1 = nn.ReLU()
+        self.linear2 = nn.Linear(in_features=64, out_features=32)
+        self.ac2 = nn.ReLU()
+        self.output = nn.Linear(in_features=32, out_features=num_labels)
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def execute(self, x):
+        # similar to the forward function in PyTorch
+        x = self.ac1(self.linear1(x))
+        x = self.ac2(self.linear2(x))
+        x = self.output(x)
+        return x
+
+    def train_step(self, x, y):
+        x = self(x)
+        return {"loss": self.loss_fn(x, y)}
+
+    def evaluate_step(self, x, y):
+        x = self(x)
+        return {"pred": x, "target": y.reshape((-1,))}
+
+
+class JittorRandomMaxDataset(Dataset):
+    def __init__(self, num_samples, num_features):
+        super(JittorRandomMaxDataset, self).__init__()
+        self.x = jt.randn((num_samples, num_features))
+        self.y = self.x.argmax(dim=1)[0]
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, item):
+        return {"x": self.x[item], "y": self.y[item]}
+
+
+class TrainJittorConfig:
+    num_labels: int = 5
+    feature_dimension: int = 5
+    lr = 1e-1
+    batch_size: int = 4
+    shuffle: bool = True
+
+
+@pytest.mark.parametrize("driver,device", [("jittor", None)])
+@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
+def test_trainer_jittor(
+        driver,
+        device,
+        callbacks,
+        n_epochs=3,
+):
+    model = JittorNormalModel_Classification(
+        num_labels=TrainJittorConfig.num_labels,
+        feature_dimension=TrainJittorConfig.feature_dimension
+    )
+    optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
+    train_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    val_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    test_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    metrics = {"acc": Accuracy()}
+
+    trainer = Trainer(
+        model=model,
+        driver=driver,
+        device=device,
+        optimizers=optimizer,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=val_dataloader,
+        validate_every=-1,
+        evaluate_fn="evaluate_step",
+        input_mapping=None,
+        output_mapping=None,
+        metrics=metrics,
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+        # progress_bar="rich"
+    )
+    trainer.run()
+
+    evaluator = Evaluator(
+        model=model,
+        driver=driver,
+        dataloaders=test_dataloader,
+        evaluate_fn="evaluate_step",
+        metrics=metrics,
+    )
+    metric_results = evaluator.run()
+    assert metric_results["acc#acc"] > 0.80
+
+
+if __name__ == "__main__":
+    # test_trainer_jittor("jittor", None, [RichCallback(100)])
+    pytest.main(['test_trainer_jittor.py'])  # only run this module
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index 8971b2fe..1eb1ea4d 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -174,7 +174,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         dist.destroy_process_group()


 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu')])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_validate_every(
         model_and_optimizers: TrainerParameters,
@@ -234,7 +234,7 @@ def test_trainer_on(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+        evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
         input_mapping=model_and_optimizers.input_mapping,
        output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
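An inferred note on the last hunk: wrapping the evaluate dataloader in a dict names it, and that name is joined last onto every flattened metric key, which is consistent with the unnamed-dataloader assertion `metric_results["acc#acc"]` in the jittor test above (values are illustrative):

```python
from fastNLP.core.utils import flat_nest_dict

# a named dataloader ('dl') appends its name to every flattened metric key,
# so the torch test's results carry 'acc#acc#dl' instead of the bare 'acc#acc'
assert flat_nest_dict({'dl': {'acc': {'acc': 0.85}}}) == {'acc#acc#dl': 0.85}
```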