Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
f27ff80fce
16 changed files with 270 additions and 107 deletions
  1. +1
    -1
      fastNLP/core/__init__.py
  2. +2
    -2
      fastNLP/core/callbacks/__init__.py
  3. +1
    -1
      fastNLP/core/callbacks/callback_manager.py
  4. +12
    -8
      fastNLP/core/callbacks/has_monitor_callback.py
  5. +3
    -3
      fastNLP/core/callbacks/topk_saver.py
  6. +14
    -44
      fastNLP/core/controllers/evaluator.py
  7. +2
    -2
      fastNLP/core/controllers/loops/evaluate_batch_loop.py
  8. +1
    -1
      fastNLP/core/controllers/trainer.py
  9. +31
    -0
      fastNLP/core/drivers/choose_driver.py
  10. +4
    -3
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  11. +8
    -2
      fastNLP/core/drivers/jittor_driver/single_device.py
  12. +1
    -34
      fastNLP/core/drivers/utils.py
  13. +2
    -3
      fastNLP/core/utils/__init__.py
  14. +53
    -1
      fastNLP/core/utils/utils.py
  15. +133
    -0
      tests/core/controllers/test_trainer_jittor.py
  16. +2
    -2
      tests/core/controllers/test_trainer_w_evaluator_torch.py

+ 1
- 1
fastNLP/core/__init__.py View File

@@ -14,7 +14,7 @@ __all__ = [
'MoreEvaluateCallback',
"TorchWarmupCallback",
"TorchGradClipCallback",
"MonitorUtility",
"ResultsMonitor",
'HasMonitorCallback',

# collators


+ 2
- 2
fastNLP/core/callbacks/__init__.py View File

@@ -16,7 +16,7 @@ __all__ = [
"TorchWarmupCallback",
"TorchGradClipCallback",

"MonitorUtility",
"ResultsMonitor",
'HasMonitorCallback'
]

@@ -31,5 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback
from .torch_callbacks import *
from .more_evaluate_callback import MoreEvaluateCallback
from .has_monitor_callback import MonitorUtility, HasMonitorCallback
from .has_monitor_callback import ResultsMonitor, HasMonitorCallback


+ 1
- 1
fastNLP/core/callbacks/callback_manager.py View File

@@ -57,7 +57,7 @@ def prepare_callbacks(callbacks, progress_bar):
if has_no_progress and progress_bar is not None:
callback = choose_progress_callback(progress_bar)
if callback is not None:
_callbacks.append(callback)
_callbacks = [callback] + _callbacks # 放在最前面,方便分割不同 epoch
has_no_progress = False
elif has_no_progress is False and progress_bar not in ('auto', None):
logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")


+ 12
- 8
fastNLP/core/callbacks/has_monitor_callback.py View File

@@ -1,7 +1,7 @@
__all__ = [
'HasMonitorCallback',
'ExecuteOnceBetterMonitor',
'MonitorUtility'
'ResultsMonitor'
]

from typing import Dict, Union, Any
@@ -29,12 +29,16 @@ class CanItemDataType(ABC):
return NotImplemented


class MonitorUtility:
"""
计算 monitor 的相关函数
class ResultsMonitor:
def __init__(self, monitor:Union[Callback, str], larger_better:bool=True):
"""
可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。

"""
def __init__(self, monitor, larger_better):
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。
:param larger_better: monitor 是否时越大越好
"""
self.set_monitor(monitor, larger_better)

def set_monitor(self, monitor, larger_better):
@@ -53,7 +57,7 @@ class MonitorUtility:

def itemize_results(self, results):
"""
将结果中有 .item() 方法的都调用一下,使得可以结果可以保存
将结果中有 .item() 方法的都调用一下,使得 tensor 类型的数据转为 python 内置类型。

:param results:
:return:
@@ -161,7 +165,7 @@ class MonitorUtility:
return monitor_name


class HasMonitorCallback(MonitorUtility, Callback):
class HasMonitorCallback(ResultsMonitor, Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了


+ 3
- 3
fastNLP/core/callbacks/topk_saver.py View File

@@ -12,7 +12,7 @@ from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.envs import rank_zero_call
from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME
from .has_monitor_callback import MonitorUtility
from .has_monitor_callback import ResultsMonitor


class Saver:
@@ -170,7 +170,7 @@ class TopkQueue:
return self.topk != 0


class TopkSaver(MonitorUtility, Saver):
class TopkSaver(ResultsMonitor, Saver):
def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model',
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
**kwargs):
@@ -196,7 +196,7 @@ class TopkSaver(MonitorUtility, Saver):
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。
"""
MonitorUtility.__init__(self, monitor, larger_better)
ResultsMonitor.__init__(self, monitor, larger_better)
Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)

if monitor is not None and topk == 0:


+ 14
- 44
fastNLP/core/controllers/evaluator.py View File

@@ -8,10 +8,10 @@ __all__ = [
]

from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from ..drivers.choose_driver import choose_driver
from .loops import Loop, EvaluateBatchLoop
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
match_and_substitute_params, f_rich_progress
match_and_substitute_params, f_rich_progress, flat_nest_dict
from fastNLP.core.metrics import Metric
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
@@ -162,13 +162,15 @@ class Evaluator:
self.cur_dataloader_name = dataloader_name
results = self.evaluate_batch_loop.run(self, dataloader)
self.remove_progress_bar(dataloader_name)
metric_results.update(results)
metric_results[dataloader_name] = results
self.reset()
self.driver.barrier()
except BaseException as e:
raise e
finally:
self.finally_progress_bar()

metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)
self.driver.set_model_mode(mode='train')
if self.verbose:
if self.progress_bar == 'rich':
@@ -251,14 +253,13 @@ class Evaluator:
"""
self.metrics_wrapper.update(batch, outputs)

def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
def get_metric(self) -> Dict:
"""
获取当前dataloader的metric结果
调用所有 metric 的 get_metric 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。

:param str dataloader_name: 当前dataloader的名字
:return:
"""
return self.metrics_wrapper.get_metric(dataloader_name=dataloader_name, separator=self.separator)
return self.metrics_wrapper.get_metric()

@property
def metrics_wrapper(self):
@@ -366,15 +367,12 @@ class _MetricsWrapper:
elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
metric.reset()

def get_metric(self, dataloader_name: str, separator: str) -> Dict:
def get_metric(self) -> Dict:
"""
将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是
indicator_name{separator}metric_name{separator}dataloader_name
例如: f1#F1PreRec#dev
调用各个 metric 得到 metric 的结果。并使用 {'metric_name1': metric_results, 'metric_name2': metric_results} 的形式
返回。

:param dataloader_name: 当前metric对应的dataloader的名字。若为空,则不显示在最终的key上面。
:param separator: 用于间隔不同称呼。
:return: 返回一个一级结构的字典,其中 key 为区别一个 metric 的名字,value 为该 metric 的值;
:return:
"""
results = {}
for metric_name, metric in zip(self._metric_names, self._metrics):
@@ -384,37 +382,9 @@ class _MetricsWrapper:
_results = metric.get_metric(reset=False)
elif _is_torchmetrics_metric(metric):
_results = metric.compute()
# 我们规定了 evaluator 中的 metrics 的输入只能是一个 dict,这样如果 metric 是一个 torchmetrics 时,如果 evaluator
# 没有传入 func_post_proc,那么我们就自动使用该 metric 的 metric name 当做其的 indicator name 将其自动转换成一个字典;
elif _is_paddle_metric(metric):
_results = metric.accumulate()
if not isinstance(_results, Dict):
name = _get_metric_res_name(dataloader_name, metric_name, '', separator)
results[name] = _results
else:
for indicator_name, value in _results.items():
name = _get_metric_res_name(dataloader_name, metric_name, indicator_name, separator)
results[name] = value

raise RuntimeError(f"Not support `{type(metric)}` for now.")
results[metric_name] = _results
return results


def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indicator_name: str, separator='#') -> str:
"""

:param dataloader_name: dataloder的名字
:param metric_name: metric的名字
:param indicator_name: metric中的各项metric名称,例如f, precision, recall
:param separator: 用以间隔不同对象的间隔符
:return:
"""
names = []
if indicator_name:
names.append(indicator_name)
if metric_name:
names.append(metric_name)
if dataloader_name:
names.append(dataloader_name)
if len(names) == 0:
raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.")
return separator.join(names)

+ 2
- 2
fastNLP/core/controllers/loops/evaluate_batch_loop.py View File

@@ -40,8 +40,8 @@ class EvaluateBatchLoop(Loop):
self.batch_step_fn(evaluator, batch)
batch_idx += 1
evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
# 获取metric结果。返回的dict内容示例为{'f1#F1Metric#dl1': 0.93, 'pre#F1Metric#dl1': 0.95, ...}
results = evaluator.get_dataloader_metric(dataloader_name=evaluator.cur_dataloader_name)
# 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...}
results = evaluator.get_metric()
return results

@staticmethod


+ 1
- 1
fastNLP/core/controllers/trainer.py View File

@@ -23,7 +23,7 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_manager import prepare_callbacks
from fastNLP.core.callbacks.callback_event import Event
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from ..drivers.choose_driver import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.envs import rank_zero_call


+ 31
- 0
fastNLP/core/drivers/choose_driver.py View File

@@ -0,0 +1,31 @@
from typing import Union, Optional, List

from .driver import Driver


def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
r"""
根据输入的参数 'gpus' 的格式来决定具体的工作模式;

:param model: 运行过程中使用的具体的最原始的模型;
:param driver: 应当为字符串或者 `Driver` 实例,表示运行中具体使用的训练/评测模式;
:param device: 具体的形式请参见 `fastNLP.core.drivers.torch_driver.utils.initialize_torch_dirver` 的注释;
:param kwargs: 其余的传给 `Driver` 的参数;
"""

# 如果用户直接传进来一个 driver 实例,我们就直接返回回去,目前用户需要自己保证传进来的 driver 的正确性;
if isinstance(driver, Driver):
return driver

if driver in {"torch", "torch_ddp", "fairscale"}:
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
return initialize_torch_driver(driver, device, model, **kwargs)
elif driver in {"jittor"}:
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
return initialize_jittor_driver(driver, device, model, **kwargs)
elif driver in {"paddle", "fleet"}:
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
return initialize_paddle_driver(driver, device, model, **kwargs)
else:
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', "
"'jittor', 'paddle', 'fleet'].")

+ 4
- 3
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -33,11 +33,12 @@ class JittorDriver(Driver):
f"`jittor.Module` type.")
super(JittorDriver, self).__init__(model)

self.model = model

self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler()

# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
# 在fastnlp中实现了JittorDataLoader
@@ -152,4 +153,4 @@ class JittorDriver(Driver):
# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
# dataloader.batch_sampler.set_epoch(cur_epoch_idx)
# dataloader.batch_sampler.set_epoch(cur_epoch_idx)

+ 8
- 2
fastNLP/core/drivers/jittor_driver/single_device.py View File

@@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
return self.model, self.model.execute
else:
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

@@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
return dataloader
else:
return dataloader

def setup(self):
"""
使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作
"""
pass

+ 1
- 34
fastNLP/core/drivers/utils.py View File

@@ -1,38 +1,5 @@
from typing import Optional
from typing import Union, List
from typing import List
import subprocess
from pathlib import Path

from fastNLP.core.drivers.driver import Driver

__all__ = []

def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
r"""
根据输入的参数 'gpus' 的格式来决定具体的工作模式;

:param model: 运行过程中使用的具体的最原始的模型;
:param driver: 应当为字符串或者 `Driver` 实例,表示运行中具体使用的训练/评测模式;
:param device: 具体的形式请参见 `fastNLP.core.drivers.torch_driver.utils.initialize_torch_dirver` 的注释;
:param kwargs: 其余的传给 `Driver` 的参数;
"""

# 如果用户直接传进来一个 driver 实例,我们就直接返回回去,目前用户需要自己保证传进来的 driver 的正确性;
if isinstance(driver, Driver):
return driver

if driver in {"torch", "torch_ddp", "fairscale"}:
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
return initialize_torch_driver(driver, device, model, **kwargs)
elif driver in {"jittor"}:
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
return initialize_jittor_driver(driver, device, model, **kwargs)
elif driver in {"paddle", "fleet"}:
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
return initialize_paddle_driver(driver, device, model, **kwargs)
else:
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', "
"'jittor', 'paddle', 'fleet'].")


def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None):


+ 2
- 3
fastNLP/core/utils/__init__.py View File

@@ -24,6 +24,7 @@ __all__ = [
'Option',
'deprecated',
'seq_len_to_mask',
"flat_nest_dict"
]

from .cache_results import cache_results
@@ -33,8 +34,6 @@ from .paddle_utils import get_device_from_visible, paddle_to, paddle_move_data_t
from .rich_progress import f_rich_progress
from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
deprecated, seq_len_to_mask
from .utils import *



+ 53
- 1
fastNLP/core/utils/utils.py View File

@@ -35,6 +35,7 @@ __all__ = [
'Option',
'deprecated',
'seq_len_to_mask',
"flat_nest_dict"
]


@@ -645,4 +646,55 @@ def is_notebook():
except:
return False
else: # pragma: no cover
return True
return True


def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict:
"""
讲一个 nested 的 dict 转成 flat 的 dict,例如
ex::
d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1}

:param d: 需要展平的 dict 对象。
:param separator: 不同层级之间的 key 之间的连接符号。
:param compress_none_key: 如果有 key 为 None ,则忽略这一层连接。
:param top_down: 新的 key 的是否按照从最底层往最底层的顺序连接。
:return:
"""
assert isinstance(d, Dict)
assert isinstance(separator, str)
flat_d = {}
for key, value in d.items():
if key is None:
key = ()
else:
key = (key, )
if isinstance(value, Mapping):
flat_d.update(_flat_nest_dict(value, parent_key=key, compress_none_key=compress_none_key))
else:
flat_d[key] = value

str_flat_d = {}
for key, value in flat_d.items():
if top_down:
key = map(str, key)
else:
key = map(str, key[::-1])
key = separator.join(key)
str_flat_d[key] = value
return str_flat_d


def _flat_nest_dict(d:Mapping, parent_key:Tuple, compress_none_key:bool):
flat_d = {}
for k, v in d.items():
_key = parent_key
if k is not None:
_key = _key + (k,)
if isinstance(v, Mapping):
_d = _flat_nest_dict(v, parent_key=_key, compress_none_key=compress_none_key)
flat_d.update(_d)
else:
flat_d[_key] = v

return flat_d

+ 133
- 0
tests/core/controllers/test_trainer_jittor.py View File

@@ -0,0 +1,133 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.trainer import Evaluator
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor import nn, Module
from jittor.dataset import Dataset


class JittorNormalModel_Classification(Module):
"""
基础的 Jittor 分类模型
"""

def __init__(self, num_labels, feature_dimension):
super(JittorNormalModel_Classification, self).__init__()
self.num_labels = num_labels

self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
self.ac1 = nn.ReLU()
self.linear2 = nn.Linear(in_features=64, out_features=32)
self.ac2 = nn.ReLU()
self.output = nn.Linear(in_features=32, out_features=num_labels)
self.loss_fn = nn.CrossEntropyLoss()

def execute(self, x):
# It's similar to forward function in Pytorch
x = self.ac1(self.linear1(x))
x = self.ac2(self.linear2(x))
x = self.output(x)
return x

def train_step(self, x, y):
x = self(x)
return {"loss": self.loss_fn(x, y)}

def evaluate_step(self, x, y):
x = self(x)
return {"pred": x, "target": y.reshape((-1,))}


class JittorRandomMaxDataset(Dataset):
def __init__(self, num_samples, num_features):
super(JittorRandomMaxDataset, self).__init__()
self.x = jt.randn((num_samples, num_features))
self.y = self.x.argmax(dim=1)[0]

def __len__(self):
return len(self.y)

def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}


class TrainJittorConfig:
num_labels: int = 5
feature_dimension: int = 5
lr = 1e-1
batch_size: int = 4
shuffle: bool = True


@pytest.mark.parametrize("driver,device", [("jittor", None)])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
def test_trainer_jittor(
driver,
device,
callbacks,
n_epochs=3,
):
model = JittorNormalModel_Classification(
num_labels=TrainJittorConfig.num_labels,
feature_dimension=TrainJittorConfig.feature_dimension
)
optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
train_dataloader = JittorDataLoader(
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
batch_size=TrainJittorConfig.batch_size,
shuffle=True,
# num_workers=4,
)
val_dataloader = JittorDataLoader(
dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
batch_size=TrainJittorConfig.batch_size,
shuffle=True,
# num_workers=4,
)
test_dataloader = JittorDataLoader(
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
batch_size=TrainJittorConfig.batch_size,
shuffle=True,
# num_workers=4,
)
metrics = {"acc": Accuracy()}

trainer = Trainer(
model=model,
driver=driver,
device=device,
optimizers=optimizer,
train_dataloader=train_dataloader,
evaluate_dataloaders=val_dataloader,
validate_every=-1,
evaluate_fn="evaluate_step",
input_mapping=None,
output_mapping=None,
metrics=metrics,
n_epochs=n_epochs,
callbacks=callbacks,
# progress_bar="rich"
)
trainer.run()

evaluator = Evaluator(
model=model,
driver=driver,
dataloaders=test_dataloader,
evaluate_fn="evaluate_step",
metrics=metrics,
)
metric_results = evaluator.run()
assert metric_results["acc#acc"] > 0.80


if __name__ == "__main__":
# test_trainer_jittor("jittor", None, [RichCallback(100)])
pytest.main(['test_trainer_jittor.py']) # 只运行此模块

+ 2
- 2
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -174,7 +174,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
dist.destroy_process_group()

@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_validate_every(
model_and_optimizers: TrainerParameters,
@@ -234,7 +234,7 @@ def test_trainer_on(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
evaluate_dataloaders={"dl":model_and_optimizers.evaluate_dataloaders},
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,


Loading…
Cancel
Save