@@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||
class _JittorDataset(Dataset): | |||
""" | |||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | |||
""" | |||
def __init__(self, dataset) -> None: | |||
@@ -83,7 +82,7 @@ class JittorDataLoader: | |||
# TODO 验证支持replacesampler (以后完成) 增加Sampler | |||
# 将内部dataset批次设置为1 | |||
if isinstance(dataset, Dataset): | |||
dataset.set_attrs(batch_size=1) | |||
dataset.set_attrs(batch_size=1, shuffle=False, endless=False) | |||
# FastNLP Datset, collate_fn not None | |||
if isinstance(dataset, FDataSet) and collate_fn is None: | |||
@@ -115,6 +114,12 @@ class JittorDataLoader: | |||
self.cur_batch_indices = None | |||
def __getattr__(self, attr): | |||
if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad", | |||
"keep_numpy_array", "endless", "sampler"]: | |||
return getattr(self.dataset, attr) | |||
raise AttributeError(f"{self} has not attribute '{attr}'") | |||
def __iter__(self): | |||
# TODO 第一次迭代后不能设置collate_fn,设置是无效的 | |||
if self.cur_batch_indices is None: | |||
@@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR: | |||
__all__ = [] | |||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver: | |||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver: | |||
r""" | |||
用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 | |||
@@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo | |||
raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | |||
# TODO 实现更详细的判断 | |||
if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: | |||
if device in ["cpu", "gpu", "cuda", None]: | |||
return JittorSingleDriver(model, device, **kwargs) | |||
elif type(device) is int: | |||
return JittorMPIDriver(model, device, **kwargs) | |||
@@ -1,23 +1,31 @@ | |||
import os | |||
import random | |||
from pathlib import Path | |||
from typing import Union, Optional | |||
from functools import partial | |||
import numpy as np | |||
from typing import Union, Optional, Dict | |||
from contextlib import nullcontext | |||
from dataclasses import dataclass | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import apply_to_collection | |||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS | |||
from fastNLP.envs import ( | |||
FASTNLP_MODEL_FILENAME, | |||
FASTNLP_CHECKPOINT_FILENAME, | |||
) | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from jittor import Module | |||
from jittor.optim import Optimizer | |||
from jittor.dataset import Dataset | |||
from jittor.dataset import ( | |||
BatchSampler as JittorBatchSampler, | |||
Sampler as JittorSampler, | |||
RandomSampler as JittorRandomSampler, | |||
SequentialSampler as JittorSequentialSampler | |||
) | |||
_reduces = { | |||
'max': jt.max, | |||
@@ -56,6 +64,7 @@ class JittorDriver(Driver): | |||
else: | |||
jt.flags.auto_mixed_precision_level = 0 | |||
self.fp16 = fp16 | |||
self._auto_cast = nullcontext | |||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
@@ -68,7 +77,7 @@ class JittorDriver(Driver): | |||
def _check_optimizer_legality(optimizers): | |||
for each_optimizer in optimizers: | |||
if not isinstance(each_optimizer, Optimizer): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
def step(self): | |||
@@ -117,30 +126,118 @@ class JittorDriver(Driver): | |||
model = self.unwrap_model() | |||
model.load(filepath) | |||
def save_checkpoint(self): | |||
... | |||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if dataloader_args.sampler: | |||
sampler = dataloader_args.sampler | |||
else: | |||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
num_consumed_batches = states.pop('num_consumed_batches') | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
sampler_states = sampler.state_dict() | |||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
# 会造成多余实际消耗的问题。因为 | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
if dataloader_args.batch_size is not None: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
else: | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
states['sampler_states'] = sampler_states | |||
else: | |||
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||
'state.') | |||
# 2. 保存模型的状态; | |||
if should_save_model: | |||
if not os.path.exists(folder): | |||
os.mkdir(folder) | |||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||
self.save_model(model_path, only_state_dict=only_state_dict) | |||
# 3. 保存 optimizers 的状态; | |||
states["optimizers_state_dict"] = self.get_optimizer_state() | |||
# 4. 保存fp16的状态 | |||
logger.debug("Save optimizer state dict") | |||
jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||
def get_optimizer_state(self): | |||
# optimizers_state_dict = {} | |||
# for i in range(len(self.optimizers)): | |||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
# optimizer_state = optimizer.state_dict() | |||
# optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||
# optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
# return optimizers_state_dict | |||
... | |||
optimizers_state_dict = {} | |||
for i in range(len(self.optimizers)): | |||
optimizer: Optimizer = self.optimizers[i] | |||
optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的; | |||
return optimizers_state_dict | |||
def load_optimizer_state(self, states): | |||
# assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||
# f"checkpoint it is:{len(states)}" | |||
# for i in range(len(self.optimizers)): | |||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
# optimizer.load_state_dict(states[f"optimizer{i}"]) | |||
# logger.debug("Load optimizer state dict.") | |||
... | |||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||
f"checkpoint it is:{len(states)}" | |||
for i in range(len(self.optimizers)): | |||
optimizer: Optimizer = self.optimizers[i] | |||
optimizer.load_state_dict(states[f"optimizer{i}"]) | |||
logger.debug("Load optimizer state dict.") | |||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||
# 1. 加载 optimizers 的状态; | |||
optimizers_state_dict = states.pop("optimizers_state_dict") | |||
self.load_optimizer_state(optimizers_state_dict) | |||
# 2. 加载模型状态; | |||
if should_load_model: | |||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||
# 3. 加载fp16的状态 | |||
# 4. 恢复 sampler 的状态; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if dataloader_args.sampler is None: | |||
sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=dataloader_args.shuffle) | |||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
sampler = dataloader_args.sampler | |||
elif isinstance(dataloader_args.sampler, JittorRandomSampler): | |||
sampler = RandomSampler(dataloader_args.sampler.dataset) | |||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||
elif isinstance(dataloader_args.sampler, JittorSequentialSampler): | |||
sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False) | |||
logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.") | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||
"`ReproducibleSampler`.") | |||
else: | |||
raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.") | |||
sampler.load_state_dict(states.pop('sampler_states')) | |||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
if dataloader_args.drop_last: | |||
batch_idx_in_epoch = len( | |||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
else: | |||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||
def load_checkpoint(self): | |||
... | |||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||
return states | |||
def get_evaluate_context(self): | |||
return jt.no_grad | |||
@@ -198,26 +295,8 @@ class JittorDriver(Driver): | |||
""" | |||
return batch | |||
@staticmethod | |||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||
process_seed = jt.get_seed() | |||
# back out the base seed so we can use all the bits | |||
base_seed = process_seed - worker_id | |||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||
# use 128 bits (4 x 32-bit words) | |||
np.random.seed(ss.generate_state(4)) | |||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||
jittor_ss, stdlib_ss = ss.spawn(2) | |||
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) | |||
# use 128 bits expressed as an integer | |||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||
random.seed(stdlib_seed) | |||
def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | |||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||
... | |||
def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | |||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||
@@ -226,4 +305,45 @@ class JittorDriver(Driver): | |||
@staticmethod | |||
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | |||
pass | |||
@dataclass | |||
class Res: | |||
dataset: Optional[Dataset] = None | |||
batch_sampler: Optional[JittorBatchSampler] = None | |||
sampler: Optional[JittorSampler] = None | |||
batch_size: Optional[int] = None | |||
shuffle: Optional[bool] = None | |||
drop_last: Optional[bool] = None | |||
res = Res() | |||
from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset | |||
if isinstance(dataloader, JittorDataLoader): | |||
# JittorDataLoader 实际上是迭代 dataset 成员的 | |||
dataloader = dataloader.dataset | |||
if isinstance(dataloader, _JittorDataset): | |||
# 获取最原始的 dataset | |||
res.dataset = dataloader.dataset | |||
else: | |||
res.dataset = dataloader | |||
# jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取 | |||
res.batch_size = dataloader.batch_size | |||
res.drop_last = dataloader.drop_last | |||
if dataloader.sampler is None: | |||
# sampler 是 None,那么就从 Dataset 的属性中获取 | |||
res.shuffle = dataloader.shuffle | |||
elif isinstance(list(dataloader.sampler.__iter__())[0], (list,tuple)): | |||
# jittor 目前不支持 batch_sampler | |||
raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||
"please check if you have set `Dataset.sampler` as `BatchSampler`") | |||
else: | |||
# sampler 不为 None | |||
res.sampler = dataloader.sampler | |||
if hasattr(dataloader.sampler, "shuffle"): | |||
# 这种情况一般出现在 fastNLP 的 ReproduceSampler 中 | |||
res.shuffle = dataloader.sampler.shuffle | |||
elif isinstance(dataloader.sampler, JittorRandomSampler): | |||
res.shuffle = True | |||
else: | |||
res.shuffle = False | |||
return res |
@@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver): | |||
): | |||
super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
raise NotImplementedError("MPI for Jittor is not supported right now.") | |||
self.is_pull_by_jittor_run = is_pull_by_jittor_run | |||
self.parallel_device = parallel_device | |||
@@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver): | |||
return self._data_device | |||
return self.parallel_device | |||
def step(self): | |||
# for optimizer in self.optimizers: | |||
# self.grad_scaler.step(optimizer) | |||
# self.grad_scaler.update() | |||
for optimizer in self.optimizers: | |||
optimizer.step() | |||
def backward(self, loss): | |||
# self.grad_scaler.scale(loss).backward() | |||
for optimizer in self.optimizers: | |||
optimizer.backward(loss) | |||
def zero_grad(self): | |||
for optimizer in self.optimizers: | |||
optimizer.zero_grad() | |||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||
@@ -1,14 +1,21 @@ | |||
from typing import Dict, Union, Tuple, Callable, Optional | |||
from .jittor_driver import JittorDriver | |||
from .utils import replace_batch_sampler, replace_sampler | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.core.utils.utils import _get_fun_msg | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||
ReproduceBatchSampler | |||
from fastNLP.core.samplers import RandomSampler | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from jittor.dataset import ( | |||
RandomSampler as JittorRandomSampler, | |||
SequentialSampler as JittorSequentialSampler, | |||
) | |||
__all__ = [ | |||
"JittorSingleDriver", | |||
@@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver): | |||
""" | |||
return False | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# reproducible 的相关功能暂时没有实现 | |||
def set_dist_repro_dataloader(self, dataloader, | |||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
raise NotImplementedError | |||
dataloader.batch_sampler = dist_sample | |||
if isinstance(dist, ReproducibleSampler): | |||
raise NotImplementedError | |||
dataloader.batch_sampler.sampler = dist | |||
return replace_batch_sampler(dataloader, dist) | |||
elif isinstance(dist, ReproducibleSampler): | |||
return replace_sampler(dataloader, dist) | |||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
elif isinstance(args.sampler, ReproducibleSampler): | |||
sampler = re_instantiate_sampler(args.sampler) | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
raise NotImplementedError | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
) | |||
dataloader.batch_sampler = batch_sampler | |||
return dataloader | |||
if args.sampler is None: | |||
sampler = RandomSampler(args.dataset, args.shuffle) | |||
return replace_sampler(dataloader, sampler) | |||
elif isinstance(args.sampler, JittorRandomSampler): | |||
if getattr(args.sampler, '_num_samples', None) is None \ | |||
and getattr(args.sampler, 'rep', False) is False: | |||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||
sampler = RandomSampler(args.sampler.dataset, shuffle=True) | |||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
elif isinstance(args.sampler, JittorSequentialSampler): | |||
# 需要替换为不要 shuffle 的。 | |||
sampler = RandomSampler(args.sampler.dataset, shuffle=False) | |||
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
batch_sampler = ReproduceBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
else: | |||
return dataloader | |||
@@ -1,6 +1,29 @@ | |||
import inspect | |||
from copy import deepcopy | |||
from typing import Union | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
from jittor.dataset import Dataset | |||
__all__ = [] | |||
def replace_batch_sampler(dataloader, batch_sampler): | |||
raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||
"please check if you have set `Dataset.sampler` as `BatchSampler`" | |||
"or report this bug to us.") | |||
def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): | |||
if isinstance(dataloader, JittorDataLoader): | |||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||
reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} | |||
reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler) | |||
new_dataloader = type(dataloader)(**reconstruct_args) | |||
new_dataloader.dataset.set_attrs(sampler=sampler) | |||
else: | |||
new_dataloader = deepcopy(dataloader) | |||
new_dataloader.set_attrs(sampler=sampler) | |||
return new_dataloader |
@@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.io import ( | |||
DataLoader, | |||
IterableDataset, | |||
Dataset, | |||
Sampler, | |||
BatchSampler, | |||
@@ -97,6 +96,9 @@ class PaddleDriver(Driver): | |||
def check_dataloader_legality(self, dataloader): | |||
if not isinstance(dataloader, DataLoader): | |||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | |||
if dataloader.batch_size is None and dataloader.batch_sampler is None: | |||
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" | |||
"is not None") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
@@ -107,7 +109,7 @@ class PaddleDriver(Driver): | |||
""" | |||
for each_optimizer in optimizers: | |||
if not isinstance(each_optimizer, Optimizer): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
@staticmethod | |||
@@ -263,9 +265,7 @@ class PaddleDriver(Driver): | |||
optimizers_state_dict = {} | |||
for i in range(len(self.optimizers)): | |||
optimizer: Optimizer = self.optimizers[i] | |||
optimizer_state = optimizer.state_dict() | |||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu") | |||
return optimizers_state_dict | |||
@@ -399,6 +399,8 @@ class PaddleDriver(Driver): | |||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | |||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||
elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)): | |||
dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx) | |||
@staticmethod | |||
def get_dataloader_args(dataloader: "DataLoader"): | |||
@@ -99,7 +99,7 @@ class TorchDriver(Driver): | |||
def _check_optimizer_legality(optimizers): | |||
for each_optimizer in optimizers: | |||
if not isinstance(each_optimizer, Optimizer): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
@staticmethod | |||
@@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
self.num_consumed_samples = 0 | |||
self.during_iter = True | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
indices = list(range(self.num_samples)) | |||
if self.shuffle: | |||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | |||
@@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
if len(indices)%self.batch_size!=0: | |||
batches.append(indices[_num_batches*self.batch_size:]) | |||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||
if len(batches) > 0: | |||
if len(batches[-1])<self.batch_size: | |||
@@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
@property | |||
def batch_idx_in_epoch(self): | |||
if self.drop_last: | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
else: | |||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | |||
@property | |||
@@ -313,8 +313,12 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
:return: | |||
""" | |||
num_consumed_samples = self.num_consumed_samples | |||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||
@property | |||
def num_samples(self): | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
def __len__(self)->int: | |||
""" | |||
@@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
" consumed. ") | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, | |||
'num_replicas': self.num_replicas} | |||
@@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||
"and current dataset." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
@@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
:return: | |||
""" | |||
num_consumed_samples = self.num_consumed_samples | |||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||
@property | |||
def num_samples(self): | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
def __len__(self)->int: | |||
""" | |||
@@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
if len(sorted_indices)%self.batch_size!=0: | |||
batches.append(sorted_indices[_num_batches*self.batch_size:]) | |||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||
if len(batches) > 0: | |||
if len(batches[-1])<self.batch_size: | |||
@@ -593,7 +601,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
" consumed. ") | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
'num_replicas': self.num_replicas | |||
} | |||
@@ -609,7 +617,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||
"and current dataset." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
@@ -630,7 +638,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
@property | |||
def batch_idx_in_epoch(self): | |||
if self.drop_last: | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
else: | |||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
(self.num_left_samples + self.batch_size - 1) // self.batch_size |
@@ -48,6 +48,10 @@ class ReproducibleSampler: | |||
def num_left_samples(self): | |||
raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | |||
@property | |||
def num_samples(self): | |||
raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.") | |||
def set_epoch(self, epoch): | |||
pass | |||
@@ -131,19 +135,19 @@ class RandomSampler(ReproducibleSampler): | |||
:return: | |||
""" | |||
if self.shuffle: | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
indices = list(range(self.num_samples)) | |||
seed = self.seed + self.epoch | |||
rng = np.random.default_rng(abs(seed)) | |||
rng.shuffle(indices) | |||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
self.epoch -= 1 | |||
else: | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
indices = list(range(self.num_samples)) | |||
return indices | |||
def state_dict(self) -> Dict: | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} | |||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
@@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||
f"record({length}) and current dataset({self.num_samples})." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
@@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler): | |||
:return: | |||
""" | |||
num_consumed_samples = self.num_consumed_samples | |||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||
@property | |||
def num_samples(self): | |||
""" | |||
返回样本的总数 | |||
:return: | |||
""" | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
class SequentialSampler(RandomSampler): | |||
""" | |||
@@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler): | |||
:return: | |||
""" | |||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
return list(range(self.num_samples)) | |||
def state_dict(self) -> Dict: | |||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
'length': getattr(self.dataset, 'total_len', len(self.dataset)) | |||
} | |||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
@@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||
f"record({length}) and current dataset({self.num_samples})." | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
self.num_consumed_samples = 0 | |||
@@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler): | |||
except BaseException as e: | |||
logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | |||
assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ | |||
f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." | |||
assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." | |||
assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \ | |||
f"`length`({self.num_samples}) should be equal." | |||
assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length." | |||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||
:return: | |||
""" | |||
num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas | |||
num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) | |||
num_common = self.num_samples//self.num_replicas | |||
num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas)) | |||
return num_samples | |||
def __iter__(self): | |||
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
:return: | |||
""" | |||
if self.shuffle: | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
indices = list(range(self.num_samples)) | |||
seed = self.seed + self.epoch | |||
rng = np.random.default_rng(abs(seed)) | |||
rng.shuffle(indices) | |||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
self.epoch -= 1 | |||
else: | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
indices = list(range(self.num_samples)) | |||
return indices | |||
def set_epoch(self, epoch: int) -> None: | |||
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
:param rank: | |||
:return: | |||
""" | |||
assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
f"number of samples({self.num_samples})." | |||
assert num_replicas>0 and isinstance(num_replicas, int) | |||
assert isinstance(rank, int) and 0<=rank<num_replicas | |||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||
@@ -94,6 +94,15 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
return self | |||
@property | |||
def num_samples(self): | |||
""" | |||
返回样本的总数 | |||
:return: | |||
""" | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
""" | |||
@@ -147,5 +156,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||
yield index | |||
def generate_indices(self) -> List[int]: | |||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
return list(range(self.num_samples)) | |||
@@ -27,7 +27,7 @@ from paddle.optimizer import Adam | |||
from paddle.io import DataLoader | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||
@dataclass | |||
@@ -52,12 +52,12 @@ def test_trainer_fleet( | |||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||
train_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
val_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
@@ -24,7 +24,7 @@ from paddle.io import DataLoader | |||
import paddle.distributed.fleet as fleet | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | |||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||
@dataclass | |||
@@ -54,12 +54,12 @@ def test_trainer_fleet( | |||
optimizers = fleet.distributed_optimizer(optimizers) | |||
train_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
val_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
@@ -46,8 +46,8 @@ class LSTM(Module): | |||
def init_hidden(self, x): | |||
# batch_first | |||
batch_size = x.shape[0] | |||
h0 = jt.randn(1, batch_size, hidden_size) | |||
c0 = jt.randn(1, batch_size, hidden_size) | |||
h0 = jt.randn(1, batch_size, self.hidden_size) | |||
c0 = jt.randn(1, batch_size, self.hidden_size) | |||
return h0, c0 | |||
@@ -1,4 +1,5 @@ | |||
import pytest | |||
from fastNLP.core.callbacks import callback | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.core.controllers.trainer import Evaluator | |||
@@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR: | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||
jt.flags.use_cuda=1 | |||
class JittorNormalModel_Classification(Module): | |||
@@ -68,11 +70,9 @@ class TrainJittorConfig: | |||
batch_size: int = 4 | |||
shuffle: bool = True | |||
@pytest.mark.parametrize("driver", ["jittor"]) | |||
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"]) | |||
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) | |||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | |||
@pytest.mark.jittor | |||
def test_trainer_jittor( | |||
driver, | |||
device, | |||
@@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE: | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||
from tests.helpers.utils import magic_argv_env_context | |||
@dataclass | |||
@@ -44,12 +44,12 @@ def test_trainer_paddle( | |||
) | |||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||
train_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||
dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||
batch_size=TrainPaddleConfig.batch_size, | |||
shuffle=True | |||
) | |||
val_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||
dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||
batch_size=TrainPaddleConfig.batch_size, | |||
shuffle=True | |||
) | |||
@@ -76,7 +76,7 @@ class TestPaddle: | |||
from paddle.io import Dataset | |||
import paddle | |||
class PaddleRandomMaxDataset(Dataset): | |||
class PaddleArgMaxDataset(Dataset): | |||
def __init__(self, num_samples, num_features): | |||
self.x = paddle.randn((num_samples, num_features)) | |||
self.y = self.x.argmax(axis=-1) | |||
@@ -87,7 +87,7 @@ class TestPaddle: | |||
def __getitem__(self, item): | |||
return {"x": self.x[item], "y": self.y[item]} | |||
ds = PaddleRandomMaxDataset(100, 2) | |||
ds = PaddleArgMaxDataset(100, 2) | |||
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | |||
for batch in dl: | |||
print(batch) |
@@ -0,0 +1,45 @@ | |||
import pytest | |||
from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver | |||
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver | |||
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
@pytest.mark.jittor | |||
def test_incorrect_driver(): | |||
model = JittorNormalModel_Classification_1(20, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_jittor_driver("torch", 0, model) | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize( | |||
"device", | |||
["cpu", "gpu", None, "cuda"] | |||
) | |||
def test_get_single_device(device): | |||
""" | |||
测试正常情况下初始化 JittorSingleDriver 的情况 | |||
""" | |||
model = JittorNormalModel_Classification_1(20, 10) | |||
driver = initialize_jittor_driver("jittor", device, model) | |||
assert isinstance(driver, JittorSingleDriver) | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize( | |||
"device", | |||
[[0, 2, 3], 1, 2] | |||
) | |||
def test_get_mpi(device): | |||
""" | |||
测试 jittor 多卡的初始化情况 | |||
""" | |||
model = JittorNormalModel_Classification_1(20, 10) | |||
with pytest.raises(NotImplementedError): | |||
driver = initialize_jittor_driver("jittor", device, model) | |||
# assert isinstance(driver, JittorMPIDriver) |
@@ -1,99 +1,614 @@ | |||
import pytest | |||
import os | |||
from copy import deepcopy | |||
from pathlib import Path | |||
import numpy as np | |||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.drivers.jittor_driver import JittorSingleDriver | |||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||
from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt # 将 jittor 引入 | |||
from jittor import nn, Module # 引入相关的模块 | |||
from jittor import init | |||
from jittor.dataset import MNIST | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
import jittor as jt | |||
from jittor.dataset import ( | |||
BatchSampler as JittorBatchSampler, | |||
RandomSampler as JittorRandomSampler, | |||
SequentialSampler as JittorSequentialSampler, | |||
SubsetRandomSampler as JittorSubsetRandomSampler | |||
) | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False): | |||
""" | |||
:param dataset: | |||
:param use_dataloader: 是否使用 JittorDataLoader 包裹 | |||
:param sampler: 使用 BatchSampler Samlper 还是不使用 Sampler | |||
""" | |||
if use_dataloader: | |||
dataloader = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) | |||
dataloader.dataset.set_attrs(sampler=sampler) | |||
else: | |||
dataloader = dataset | |||
dataloader.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler) | |||
class Model(Module): | |||
def __init__ (self): | |||
super (Model, self).__init__() | |||
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | |||
self.conv2 = nn.Conv (32, 64, 3, 1) | |||
self.bn = nn.BatchNorm(64) | |||
self.max_pool = nn.Pool (2, 2) | |||
self.relu = nn.Relu() | |||
self.fc1 = nn.Linear (64 * 12 * 12, 256) | |||
self.fc2 = nn.Linear (256, 10) | |||
def execute(self, x) : | |||
# it's simliar to forward function in Pytorch | |||
x = self.conv1 (x) | |||
x = self.relu (x) | |||
x = self.conv2 (x) | |||
x = self.bn (x) | |||
x = self.relu (x) | |||
x = self.max_pool (x) | |||
x = jt.reshape (x, [x.shape[0], -1]) | |||
x = self.fc1 (x) | |||
x = self.relu(x) | |||
x = self.fc2 (x) | |||
return x | |||
return dataloader | |||
############################################################################ | |||
# | |||
# 测试基类 JittorDrvier 中的一些简单函数 | |||
# | |||
############################################################################ | |||
class TestJittorDriverFunctions: | |||
""" | |||
使用 JittorSingleDriver 测试基类的函数 | |||
""" | |||
@classmethod | |||
def setup_class(self): | |||
model = JittorNormalModel_Classification_1(10, 32) | |||
self.driver = JittorSingleDriver(model, device="cpu") | |||
@pytest.mark.jittor | |||
def test_check_optimizers_legality(self): | |||
""" | |||
测试对合法的 optimizers 的检查 | |||
""" | |||
# 单个 optimizer | |||
optimizer = jt.optim.Adam( | |||
params=self.driver.model.parameters(), | |||
lr=0.01 | |||
) | |||
self.driver.set_optimizers(optimizer) | |||
# optimizer 列表 | |||
optimizers = [ | |||
jt.optim.Adam( | |||
params=self.driver.model.parameters(), | |||
lr=0.01 | |||
) for i in range(10) | |||
] | |||
self.driver.set_optimizers(optimizers) | |||
@pytest.mark.torchjittor | |||
def test_invalid_optimizers(self): | |||
""" | |||
测试传入非法的 optimizers | |||
""" | |||
# 单个 optimizer | |||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
with pytest.raises(TypeError): | |||
self.driver.set_optimizers(optimizer) | |||
optimizers = [ | |||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
] | |||
with pytest.raises(TypeError): | |||
self.driver.set_optimizers(optimizers) | |||
@pytest.mark.jittor | |||
def test_check_dataloader_legality(self): | |||
""" | |||
测试 check_dataloader_legality 函数的表现 | |||
""" | |||
# 使用 JittorDataLoader | |||
dataloader = JittorDataLoader(JittorNormalDataset()) | |||
self.driver.check_dataloader_legality(dataloader) | |||
# 使用 jittor.dataset.Dataset | |||
self.driver.check_dataloader_legality(JittorNormalDataset()) | |||
@pytest.mark.torchjittor | |||
def test_check_dataloader_legality_invalid(self): | |||
""" | |||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||
""" | |||
# 创建 torch 的 dataloader | |||
dataloader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(TypeError): | |||
self.driver.check_dataloader_legality(dataloader) | |||
@pytest.mark.jittor | |||
def test_tensor_to_numeric(self): | |||
""" | |||
测试 tensor_to_numeric 函数 | |||
""" | |||
# 单个张量 | |||
tensor = jt.Var(3) | |||
res = JittorSingleDriver.tensor_to_numeric(tensor) | |||
assert res == 3 | |||
tensor = jt.rand(3, 4) | |||
res = JittorSingleDriver.tensor_to_numeric(tensor) | |||
assert res == tensor.tolist() | |||
# 张量list | |||
tensor_list = [jt.rand(6, 4, 2) for i in range(10)] | |||
res = JittorSingleDriver.tensor_to_numeric(tensor_list) | |||
assert isinstance(res, list) | |||
tensor_list = [t.tolist() for t in tensor_list] | |||
assert res == tensor_list | |||
# 张量tuple | |||
tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)]) | |||
res = JittorSingleDriver.tensor_to_numeric(tensor_tuple) | |||
assert isinstance(res, tuple) | |||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||
assert res == tensor_tuple | |||
# 张量dict | |||
tensor_dict = { | |||
"tensor": jt.rand(3, 4), | |||
"list": [jt.rand(6, 4, 2) for i in range(10)], | |||
"dict":{ | |||
"list": [jt.rand(6, 4, 2) for i in range(10)], | |||
"tensor": jt.rand(3, 4) | |||
}, | |||
"int": 2, | |||
"string": "test string" | |||
} | |||
res = JittorSingleDriver.tensor_to_numeric(tensor_dict) | |||
assert isinstance(res, dict) | |||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||
assert isinstance(res["list"], list) | |||
for r, d in zip(res["list"], tensor_dict["list"]): | |||
assert r == d.tolist() | |||
assert isinstance(res["int"], int) | |||
assert isinstance(res["string"], str) | |||
assert isinstance(res["dict"], dict) | |||
assert isinstance(res["dict"]["list"], list) | |||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||
assert r == d.tolist() | |||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||
@pytest.mark.jittor | |||
def test_tensor_to_numeric_reduce(self): | |||
tensor = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||
res_max = JittorSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||
res_min = JittorSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||
res_sum = JittorSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||
res_mean = JittorSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||
assert res_max == 6 | |||
assert res_min == 1 | |||
assert res_sum == 21 | |||
assert res_mean == 3.5 | |||
@pytest.mark.jittor | |||
def test_set_model_mode(self): | |||
""" | |||
测试 set_model_mode 函数 | |||
""" | |||
self.driver.set_model_mode("train") | |||
assert self.driver.model.is_training() | |||
self.driver.set_model_mode("eval") | |||
assert not self.driver.model.is_training() | |||
# 应该报错 | |||
with pytest.raises(AssertionError): | |||
self.driver.set_model_mode("test") | |||
@pytest.mark.jittor | |||
def test_move_model_to_device_cpu(self): | |||
""" | |||
测试 move_model_to_device 函数,仅测试能否运行 | |||
""" | |||
JittorSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||
@pytest.mark.jittor | |||
def test_move_model_to_device_gpu(self): | |||
""" | |||
测试 move_model_to_device 函数,仅测试能否运行 | |||
""" | |||
JittorSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||
@pytest.mark.jittor | |||
def test_set_deterministic_dataloader(self): | |||
""" | |||
测试 set_deterministic_dataloader,仅测试能否运行 | |||
""" | |||
# 先确保不影响运行 | |||
# TODO:正确性 | |||
dataloader = JittorDataLoader(JittorNormalDataset()) | |||
self.driver.set_deterministic_dataloader(dataloader) | |||
self.driver.set_deterministic_dataloader(JittorNormalDataset()) | |||
@pytest.mark.jittor | |||
def test_set_sampler_epoch(self): | |||
""" | |||
测试 set_sampler_epoch | |||
""" | |||
# 先确保不影响运行 | |||
# TODO:正确性 | |||
dataloader = JittorDataLoader(JittorNormalDataset()) | |||
self.driver.set_sampler_epoch(dataloader, 0) | |||
self.driver.set_sampler_epoch(JittorNormalDataset(), 0) | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize("batch_size", [16]) | |||
@pytest.mark.parametrize("shuffle", [True, False]) | |||
@pytest.mark.parametrize("drop_last", [True, False]) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader): | |||
""" | |||
测试正常情况下 get_dataloader_args 的表现 | |||
""" | |||
dataloader = get_dataloader( | |||
JittorNormalDataset(), | |||
use_dataloader=use_dataloader, | |||
sampler=None, | |||
batch_size=batch_size, | |||
shuffle=shuffle, | |||
drop_last=drop_last | |||
) | |||
res = JittorSingleDriver.get_dataloader_args(dataloader) | |||
assert isinstance(res.dataset, JittorNormalDataset) | |||
assert res.sampler is None | |||
assert res.shuffle == shuffle | |||
assert res.batch_size == batch_size | |||
assert res.drop_last == drop_last | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize("batch_size", [16]) | |||
@pytest.mark.parametrize("shuffle", [True, False]) | |||
@pytest.mark.parametrize("drop_last", [True, False]) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader): | |||
""" | |||
测试替换了 sampler 后 get_dataloader_args 的表现 | |||
""" | |||
dataset = JittorNormalDataset() | |||
dataloader = get_dataloader( | |||
dataset, | |||
use_dataloader=use_dataloader, | |||
batch_size=batch_size, | |||
sampler=RandomSampler(dataset, shuffle=shuffle), | |||
shuffle=shuffle, | |||
drop_last=drop_last | |||
) | |||
res = JittorSingleDriver.get_dataloader_args(dataloader) | |||
assert isinstance(res.dataset, JittorNormalDataset) | |||
assert isinstance(res.sampler, RandomSampler) | |||
assert res.shuffle == shuffle | |||
assert res.batch_size == batch_size | |||
assert res.drop_last == drop_last | |||
############################################################################ | |||
# | |||
# 测试 JittorSingleDrvier 中的一些简单函数 | |||
# | |||
############################################################################ | |||
@pytest.mark.jittor | |||
@pytest.mark.skip("Skip jittor tests now.") | |||
class TestSingleDevice: | |||
def test_on_gpu_without_fp16(self): | |||
# TODO get_dataloader | |||
batch_size = 64 | |||
learning_rate = 0.1 | |||
epochs = 5 | |||
losses = [] | |||
losses_idx = [] | |||
train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True) | |||
val_loader = MNIST(train=False, batch_size=1, shuffle=False) | |||
model = Model() | |||
driver = JittorSingleDriver(model, device=[1]) | |||
optimizer = nn.SGD(model.parameters(), learning_rate) | |||
driver.set_optimizers(optimizer) | |||
for epoch in range(epochs): | |||
driver.set_model_mode("train") | |||
lens = len(train_loader) | |||
for batch_idx, (inputs, targets) in enumerate(train_loader): | |||
outputs =driver.train_step(inputs) | |||
loss = nn.cross_entropy_loss(outputs, targets) | |||
driver.backward(loss) | |||
driver.step() | |||
driver.zero_grad() | |||
losses.append(loss.data[0]) | |||
losses_idx.append(epoch * lens + batch_idx) | |||
test_loss = 0 | |||
correct = 0 | |||
total_acc = 0 | |||
total_num = 0 | |||
driver.set_model_mode("eval") | |||
for batch_idx, (inputs, targets) in enumerate(val_loader): | |||
batch_size = inputs.shape[0] | |||
outputs = driver.test_step(inputs) | |||
pred = np.argmax(outputs.data, axis=1) | |||
acc = np.sum(targets.data==pred) | |||
total_acc += acc | |||
total_num += batch_size | |||
acc = acc / batch_size | |||
assert total_acc / total_num > 0.95 | |||
def test_on_cpu_without_fp16(self): | |||
pass | |||
def test_on_gpu_with_fp16(self): | |||
pass | |||
class TestSingleDeviceFunction: | |||
""" | |||
测试其它函数的测试例 | |||
""" | |||
@classmethod | |||
def setup_class(cls): | |||
model = JittorNormalModel_Classification_1(10, 784) | |||
cls.driver = JittorSingleDriver(model, device="cpu") | |||
def test_unwrap_model(self): | |||
""" | |||
测试能否运行 | |||
""" | |||
res = self.driver.unwrap_model() | |||
assert res is self.driver.model | |||
def test_is_distributed(self): | |||
assert self.driver.is_distributed() == False | |||
def test_move_data_to_device(self): | |||
self.driver.move_data_to_device(jt.rand(32, 64)) | |||
############################################################################ | |||
# | |||
# 测试 set_dist_repro_dataloader 函数 | |||
# | |||
############################################################################ | |||
@pytest.mark.jittor | |||
class TestSetDistReproDataloader: | |||
""" | |||
专门测试 set_dist_repro_dataloader 函数的类 | |||
""" | |||
def setup_method(self): | |||
self.dataset = JittorNormalDataset(20) | |||
model = JittorNormalModel_Classification_1(10, 32) | |||
self.driver = JittorSingleDriver(model, device="cpu") | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_with_reproducible_false(self, use_dataloader): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||
当dist为字符串时,此时应该返回原来的 dataloader | |||
""" | |||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
assert replaced_loader is dataloader | |||
@pytest.mark.parametrize("shuffle", [True, False]) | |||
@pytest.mark.parametrize("sampler", [None, "random", "sequential"]) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_with_reproducible_true(self, shuffle, sampler, use_dataloader): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler | |||
""" | |||
if sampler == "random": | |||
sampler = JittorRandomSampler(self.dataset) | |||
_shuffle = True | |||
elif sampler == "sequential": | |||
sampler = JittorSequentialSampler(self.dataset) | |||
_shuffle = False | |||
else: | |||
_shuffle = shuffle | |||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||
assert not (replaced_loader is dataloader) | |||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||
assert replaced_loader.sampler.shuffle == _shuffle | |||
assert replaced_loader.batch_size == dataloader.batch_size | |||
assert replaced_loader.drop_last == dataloader.drop_last | |||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_with_dist_batch_sampler(self, shuffle, use_dataloader): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||
jittor 暂时不支持这种情况,会报错 | |||
""" | |||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) | |||
dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False) | |||
with pytest.raises(RuntimeError): | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_with_dist_sampler(self, shuffle, use_dataloader): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||
应该返回新的 dataloader,并将 sampler 替换为 dist 对应的 Sampler | |||
""" | |||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) | |||
dist = RandomSampler(self.dataset, shuffle=shuffle) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||
assert not (replaced_loader is dataloader) | |||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||
assert replaced_loader.sampler is dist | |||
assert replaced_loader.batch_size == dataloader.batch_size | |||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||
""" | |||
dataloader = get_dataloader( | |||
self.dataset, | |||
use_dataloader=use_dataloader, | |||
sampler=ReproduceBatchSampler( | |||
JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), | |||
batch_size=4, | |||
drop_last=False, | |||
), | |||
batch_size=4, | |||
shuffle=shuffle, | |||
) | |||
with pytest.raises(RuntimeError): | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||
""" | |||
dataloader = get_dataloader( | |||
self.dataset, | |||
use_dataloader=use_dataloader, | |||
sampler=RandomSampler(self.dataset, shuffle), | |||
batch_size=2, | |||
shuffle=shuffle, | |||
) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
assert not (replaced_loader is dataloader) | |||
assert not (replaced_loader.sampler is dataloader.sampler) | |||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||
assert replaced_loader.batch_size == 2 | |||
assert replaced_loader.shuffle == shuffle | |||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader): | |||
""" | |||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||
""" | |||
# 迭代两个 batch | |||
num_consumed_batches = 2 | |||
already_seen_idx = set() | |||
replaced_loader.sampler.set_epoch(6) | |||
for idx, batch in enumerate(replaced_loader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_idx.update(batch.tolist()) | |||
sampler_states = replaced_loader.sampler.state_dict() | |||
# 重新加载,应该可以输出剩下的内容,且对于 JittorNormalDataset 来说,排序后应该是一个 range | |||
left_idxes = set() | |||
batch_size = replaced_loader.batch_size | |||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
# 重新构造 dataloader | |||
if use_dataloader: | |||
dataset = deepcopy(replaced_loader.dataset.dataset) | |||
else: | |||
dataset = deepcopy(replaced_loader) | |||
new_loader = get_dataloader( | |||
dataset=dataset, | |||
use_dataloader=use_dataloader, | |||
sampler = RandomSampler(dataset, shuffle=shuffle), | |||
batch_size=batch_size, | |||
shuffle=shuffle, | |||
drop_last=False | |||
) | |||
new_loader.sampler.load_state_dict(sampler_states) | |||
new_loader.sampler.set_epoch(6) | |||
for idx, batch in enumerate(new_loader): | |||
left_idxes.update(batch.tolist()) | |||
print(already_seen_idx) | |||
print(left_idxes) | |||
assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len | |||
assert len(left_idxes | already_seen_idx) == self.dataset.total_len | |||
############################################################################ | |||
# | |||
# 测试 save 和 load 相关的功能 | |||
# | |||
############################################################################ | |||
def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01): | |||
""" | |||
生成driver | |||
""" | |||
model = JittorNormalModel_Classification_1(labels, features) | |||
opt = jt.optim.Adam(params=model.parameters(), lr=lr) | |||
driver = JittorSingleDriver(model, device=device, fp16=fp16) | |||
driver.set_optimizers(opt) | |||
driver.setup() | |||
return driver | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_save_and_load_model(only_state_dict, use_dataloader): | |||
""" | |||
测试 save_model 和 load_model 函数 | |||
""" | |||
try: | |||
path = "model" | |||
dataset = JittorNormalXYDataset(20) | |||
dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True) | |||
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||
driver1.save_model(path, only_state_dict) | |||
driver2.load_model(path, only_state_dict) | |||
for batch in dataloader: | |||
batch = driver1.move_data_to_device(batch) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
assert jt.all_(jt.equal(res1["pred"], res2["pred"])) | |||
finally: | |||
rank_zero_rm(path) | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||
def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader): | |||
""" | |||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||
""" | |||
try: | |||
path = "model.ckp" | |||
driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \ | |||
generate_random_driver(20, 1, device="gpu", lr=0.001) | |||
dataset = JittorNormalXYDataset(20) | |||
dataloader = get_dataloader( | |||
dataset, use_dataloader, | |||
sampler = RandomSampler(dataset, True), | |||
batch_size=4, | |||
shuffle=True | |||
) | |||
num_consumed_batches = 2 | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 7) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
sampler_states = dataloader.sampler.state_dict() | |||
save_states = {"num_consumed_batches": num_consumed_batches} | |||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||
# 加载 | |||
# 更改 batch_size | |||
dataloader = get_dataloader( | |||
dataset, use_dataloader, | |||
sampler=RandomSampler(dataset, True), | |||
batch_size=2, | |||
shuffle=True | |||
) | |||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||
replaced_loader = load_states.pop("dataloader") | |||
# 1. 检查 optimizer 的状态 | |||
assert driver2.optimizers[0].lr == driver1.optimizers[0].lr | |||
# 2. 检查 sampler 是否被正确地加载和替换 | |||
assert not (replaced_loader is dataloader) | |||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||
assert replaced_loader.sampler.seed == sampler_states["seed"] | |||
assert replaced_loader.sampler.epoch == sampler_states["epoch"] | |||
assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches | |||
assert replaced_loader.sampler.dataset.total_len == sampler_states["length"] | |||
assert replaced_loader.sampler.shuffle == sampler_states["shuffle"] | |||
# 4. 检查 model 的参数是否正确 | |||
# 5. 检查 batch_idx | |||
start_batch = load_states.pop('batch_idx_in_epoch') | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
driver2.set_sampler_epoch(replaced_loader, 7) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
assert jt.all_(jt.equal(res1["pred"], res2["pred"])) | |||
assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len | |||
assert len(left_x_batches | already_seen_x_set) == dataset.total_len | |||
assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len | |||
assert len(left_y_batches | already_seen_y_set) == dataset.total_len | |||
finally: | |||
rank_zero_rm(path) |
@@ -0,0 +1,43 @@ | |||
import pytest | |||
from fastNLP.core.drivers.jittor_driver.utils import replace_sampler | |||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from tests.helpers.datasets.jittor_data import JittorNormalDataset | |||
@pytest.mark.jittor | |||
@pytest.mark.parametrize("dataset", [ | |||
JittorNormalDataset(20, batch_size=10, shuffle=True), | |||
JittorNormalDataset(20, batch_size=5, drop_last=True), | |||
JittorNormalDataset(20) | |||
]) | |||
def test_replace_sampler_dataset(dataset): | |||
dataset = JittorNormalDataset(20) | |||
sampler = RandomSampler(dataset) | |||
replaced_loader = replace_sampler(dataset, sampler) | |||
assert not (replaced_loader is dataset) | |||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||
assert replaced_loader.batch_size == dataset.batch_size | |||
assert replaced_loader.drop_last == dataset.drop_last | |||
assert replaced_loader.shuffle == dataset.shuffle | |||
assert replaced_loader.total_len == dataset.total_len | |||
@pytest.mark.jittor | |||
def test_replace_sampler_jittordataloader(): | |||
dataset = JittorNormalDataset(20, batch_size=10, shuffle=True) | |||
dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True) | |||
sampler = RandomSampler(dataset) | |||
replaced_loader = replace_sampler(dataloader, sampler) | |||
assert not (replaced_loader is dataloader) | |||
assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset) | |||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||
assert replaced_loader.batch_size == 8 | |||
assert replaced_loader.shuffle == True |
@@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||
UnrepeatedSequentialSampler, | |||
) | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
@@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE: | |||
import paddle.distributed as dist | |||
from paddle.io import DataLoader, BatchSampler | |||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) | |||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||
paddle_model = PaddleNormalModel_Classification_1(labels, features) | |||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
driver = PaddleFleetDriver( | |||
model=paddle_model, | |||
@@ -465,10 +465,14 @@ class TestSetDistReproDataloader: | |||
num_replicas = len(self.device) | |||
num_consumed_batches = 2 | |||
already_seen_idx = set() | |||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.set_epoch(10) | |||
else: | |||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10) | |||
for idx, batch in enumerate(replaced_loader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_idx.update(batch) | |||
already_seen_idx.update(batch.tolist()) | |||
dist.barrier() | |||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||
@@ -496,6 +500,7 @@ class TestSetDistReproDataloader: | |||
pad=True | |||
) | |||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.set_epoch(10) | |||
else: | |||
batch_size = replaced_loader.batch_sampler.batch_size | |||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||
@@ -508,8 +513,9 @@ class TestSetDistReproDataloader: | |||
) | |||
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | |||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.sampler.set_epoch(10) | |||
for idx, batch in enumerate(new_loader): | |||
left_idxes.update(batch) | |||
left_idxes.update(batch.tolist()) | |||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||
@@ -533,7 +539,7 @@ class TestSaveLoad: | |||
cls.driver = generate_driver(10, 10, device=[0,1]) | |||
def setup_method(self): | |||
self.dataset = PaddleRandomMaxDataset(20, 10) | |||
self.dataset = PaddleNormalXYDataset(40) | |||
@magic_argv_env_context | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
@@ -545,12 +551,12 @@ class TestSaveLoad: | |||
path = "model" | |||
dataloader = DataLoader(self.dataset, batch_size=2) | |||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||
self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1) | |||
if only_state_dict: | |||
self.driver1.save_model(path, only_state_dict) | |||
else: | |||
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) | |||
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))]) | |||
# 同步 | |||
dist.barrier() | |||
@@ -594,8 +600,8 @@ class TestSaveLoad: | |||
path = "model.ckp" | |||
num_replicas = len(device) | |||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||
generate_driver(10, 10, device=device, fp16=False) | |||
self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \ | |||
generate_driver(40, 1, device=device, fp16=False) | |||
dataloader = DataLoader( | |||
dataset=self.dataset, | |||
batch_sampler=BucketedBatchSampler( | |||
@@ -613,11 +619,12 @@ class TestSaveLoad: | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
self.driver1.set_sampler_epoch(dataloader, 2) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
# 同步 | |||
dist.barrier() | |||
@@ -669,10 +676,11 @@ class TestSaveLoad: | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
self.driver2.set_sampler_epoch(replaced_loader, 2) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
res1 = self.driver1.model( | |||
batch, | |||
fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | |||
@@ -709,8 +717,8 @@ class TestSaveLoad: | |||
num_replicas = len(device) | |||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||
self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||
self.driver1 = generate_driver(40, 1, device=device, fp16=fp16) | |||
self.driver2 = generate_driver(40, 1, device=device, fp16=False) | |||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) | |||
batch_sampler.sampler = RandomSampler(self.dataset, True) | |||
batch_sampler.sampler.set_distributed( | |||
@@ -726,11 +734,12 @@ class TestSaveLoad: | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
self.driver1.set_sampler_epoch(dataloader, 2) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
# 同步 | |||
dist.barrier() | |||
@@ -779,10 +788,11 @@ class TestSaveLoad: | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
self.driver2.set_sampler_epoch(replaced_loader, 2) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
res1 = self.driver1.model( | |||
batch, | |||
fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | |||
@@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE: | |||
@pytest.mark.paddle | |||
def test_incorrect_driver(): | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
model = PaddleNormalModel_Classification_1(20, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver("torch", 0, model) | |||
@@ -26,7 +26,7 @@ def test_get_single_device(device): | |||
测试正常情况下初始化 PaddleSingleDriver 的情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
model = PaddleNormalModel_Classification_1(20, 10) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
assert isinstance(driver, PaddleSingleDriver) | |||
@@ -41,7 +41,7 @@ def test_get_fleet(device): | |||
测试 fleet 多卡的初始化情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
model = PaddleNormalModel_Classification_1(20, 10) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
assert isinstance(driver, PaddleFleetDriver) | |||
@@ -56,6 +56,6 @@ def test_device_out_of_range(device): | |||
""" | |||
测试传入的device超过范围的情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
model = PaddleNormalModel_Classification_1(20, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver("paddle", device, model) |
@@ -4,14 +4,16 @@ from pathlib import Path | |||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.io import DataLoader, BatchSampler | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
@@ -31,102 +33,70 @@ class TestPaddleDriverFunctions: | |||
model = PaddleNormalModel_Classification_1(10, 32) | |||
self.driver = PaddleSingleDriver(model, device="cpu") | |||
@pytest.mark.torchpaddle | |||
def test_check_single_optimizer_legality(self): | |||
@pytest.mark.paddle | |||
def test_check_optimizers_legality(self): | |||
""" | |||
测试传入单个 optimizer 时的表现 | |||
测试对合法的 optimizers 的检查 | |||
""" | |||
# 单个 optimizer | |||
optimizer = paddle.optimizer.Adam( | |||
parameters=self.driver.model.parameters(), | |||
learning_rate=0.01 | |||
) | |||
self.driver.set_optimizers(optimizer) | |||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
# 传入torch的optimizer时,应该报错ValueError | |||
with pytest.raises(ValueError): | |||
self.driver.set_optimizers(optimizer) | |||
@pytest.mark.torchpaddle | |||
def test_check_optimizers_legality(self): | |||
""" | |||
测试传入 optimizer list 的表现 | |||
""" | |||
# optimizer 列表 | |||
optimizers = [ | |||
paddle.optimizer.Adam( | |||
parameters=self.driver.model.parameters(), | |||
learning_rate=0.01 | |||
) for i in range(10) | |||
] | |||
self.driver.set_optimizers(optimizers) | |||
optimizers += [ | |||
@pytest.mark.torchpaddle | |||
def test_invalid_optimizers(self): | |||
""" | |||
测试传入非法的 optimizers | |||
""" | |||
# 单个 optimizer | |||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
with pytest.raises(TypeError): | |||
self.driver.set_optimizers(optimizer) | |||
optimizers = [ | |||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
] | |||
with pytest.raises(ValueError): | |||
with pytest.raises(TypeError): | |||
self.driver.set_optimizers(optimizers) | |||
@pytest.mark.torchpaddle | |||
def test_check_dataloader_legality_in_train(self): | |||
@pytest.mark.paddle | |||
def test_check_dataloader_legality(self): | |||
""" | |||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||
测试 check_dataloader_legality 函数的表现 | |||
""" | |||
dataloader = DataLoader(PaddleNormalDataset()) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
self.driver.check_dataloader_legality(dataloader) | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建torch的dataloader | |||
dataloader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
self.driver.check_dataloader_legality(dataloader) | |||
@pytest.mark.torchpaddle | |||
def test_check_dataloader_legality_in_test(self): | |||
def test_check_dataloader_legality_invalid(self): | |||
""" | |||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||
""" | |||
# 此时传入的应该是dict | |||
dataloader = { | |||
"train": DataLoader(PaddleNormalDataset()), | |||
"test":DataLoader(PaddleNormalDataset()) | |||
} | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = { | |||
"train": DataLoader(PaddleNormalDataset()), | |||
"test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||
} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 传入的不是 dict ,应该报错 | |||
dataloader = DataLoader(PaddleNormalDataset()) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建 torch 的 dataloader | |||
train_loader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
test_loader = torch.utils.data.DataLoader( | |||
dataloader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
with pytest.raises(TypeError): | |||
self.driver.check_dataloader_legality(dataloader) | |||
@pytest.mark.paddle | |||
def test_tensor_to_numeric(self): | |||
@@ -505,10 +475,14 @@ class TestSetDistReproDataloader: | |||
# 迭代两个 batch | |||
num_consumed_batches = 2 | |||
already_seen_idx = set() | |||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.set_epoch(5) | |||
else: | |||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5) | |||
for idx, batch in enumerate(replaced_loader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_idx.update(batch) | |||
already_seen_idx.update(batch.tolist()) | |||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||
else: | |||
@@ -529,6 +503,7 @@ class TestSetDistReproDataloader: | |||
) | |||
) | |||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.set_epoch(5) | |||
else: | |||
batch_size = replaced_loader.batch_sampler.batch_size | |||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
@@ -537,8 +512,9 @@ class TestSetDistReproDataloader: | |||
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) | |||
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | |||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.sampler.set_epoch(5) | |||
for idx, batch in enumerate(new_loader): | |||
left_idxes.update(batch) | |||
left_idxes.update(batch.tolist()) | |||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||
@@ -549,7 +525,7 @@ class TestSetDistReproDataloader: | |||
# | |||
############################################################################ | |||
def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||
def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||
""" | |||
生成driver | |||
""" | |||
@@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict): | |||
""" | |||
try: | |||
path = "model" | |||
dataset = PaddleRandomMaxDataset(40, 10) | |||
dataset = PaddleNormalXYDataset(20) | |||
dataloader = DataLoader(dataset, batch_size=4) | |||
driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") | |||
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||
if only_state_dict: | |||
driver1.save_model(path, only_state_dict) | |||
@@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict): | |||
driver2.load_model(path, only_state_dict) | |||
for batch in dataloader: | |||
print("?") | |||
batch = driver1.move_data_to_device(batch) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
@@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
try: | |||
path = "model.ckp" | |||
dataset = PaddleRandomMaxDataset(40, 10) | |||
dataset = PaddleNormalXYDataset(40) | |||
dataloader = DataLoader( | |||
dataset=dataset, | |||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||
) | |||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||
num_consumed_batches = 2 | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 3) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
sampler_states = dataloader.batch_sampler.state_dict() | |||
save_states = {"num_consumed_batches": num_consumed_batches} | |||
@@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
driver2.set_sampler_epoch(replaced_loader, 3) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
@@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
@pytest.mark.parametrize("fp16", ([True, False])) | |||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
""" | |||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||
""" | |||
try: | |||
path = "model.ckp" | |||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||
dataset = PaddleRandomMaxDataset(40, 10) | |||
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||
dataset = PaddleNormalXYDataset(40) | |||
batch_sampler = BatchSampler(dataset=dataset, batch_size=4) | |||
batch_sampler.sampler = RandomSampler(dataset, True) | |||
dataloader = DataLoader( | |||
@@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 3) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||
save_states = {"num_consumed_batches": num_consumed_batches} | |||
@@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
driver1.set_sampler_epoch(replaced_loader, 3) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
@@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||
UnrepeatedSequentialSampler, | |||
) | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
@@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH: | |||
import torch.distributed as dist | |||
from torch.utils.data import DataLoader, BatchSampler | |||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): | |||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | |||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): | |||
torch_model = TorchNormalModel_Classification_1(labels, features) | |||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||
device = [torch.device(i) for i in device] | |||
driver = TorchDDPDriver( | |||
@@ -504,10 +504,14 @@ class TestSetDistReproDataloader: | |||
num_replicas = len(self.device) | |||
num_consumed_batches = 2 | |||
already_seen_idx = set() | |||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.set_epoch(4) | |||
else: | |||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4) | |||
for idx, batch in enumerate(replaced_loader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_idx.update(batch) | |||
already_seen_idx.update(batch.tolist()) | |||
dist.barrier() | |||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||
@@ -533,6 +537,7 @@ class TestSetDistReproDataloader: | |||
pad=True | |||
) | |||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.set_epoch(4) | |||
else: | |||
batch_size = replaced_loader.batch_sampler.batch_size | |||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||
@@ -543,8 +548,9 @@ class TestSetDistReproDataloader: | |||
rank=driver.global_rank | |||
) | |||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.sampler.set_epoch(4) | |||
for idx, batch in enumerate(new_loader): | |||
left_idxes.update(batch) | |||
left_idxes.update(batch.tolist()) | |||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||
@@ -562,7 +568,7 @@ class TestSaveLoad: | |||
""" | |||
def setup_method(self): | |||
self.dataset = TorchArgMaxDataset(10, 20) | |||
self.dataset = TorchNormalXYDataset(20) | |||
@magic_argv_env_context | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
@@ -574,7 +580,7 @@ class TestSaveLoad: | |||
path = "model" | |||
dataloader = DataLoader(self.dataset, batch_size=2) | |||
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||
driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) | |||
driver1.save_model(path, only_state_dict) | |||
@@ -618,8 +624,8 @@ class TestSaveLoad: | |||
path = "model.ckp" | |||
num_replicas = len(device) | |||
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||
generate_driver(10, 10, device=device, fp16=False) | |||
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ | |||
generate_driver(20, 1, device=device, fp16=False) | |||
dataloader = dataloader_with_bucketedbatchsampler( | |||
self.dataset, | |||
length=[10 for i in range(len(self.dataset))], | |||
@@ -636,11 +642,12 @@ class TestSaveLoad: | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 4) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
# 同步 | |||
dist.barrier() | |||
@@ -665,7 +672,6 @@ class TestSaveLoad: | |||
pad=True | |||
) | |||
dist.barrier() | |||
print("========load=======", driver1.global_rank, driver2.global_rank) | |||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||
dist.barrier() | |||
replaced_loader = load_states.pop("dataloader") | |||
@@ -690,10 +696,11 @@ class TestSaveLoad: | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
driver2.set_sampler_epoch(replaced_loader, 4) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
res1 = driver1.model( | |||
batch, | |||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||
@@ -716,7 +723,6 @@ class TestSaveLoad: | |||
dist.barrier() | |||
finally: | |||
rank_zero_rm(path) | |||
print("=======delete======") | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
@@ -735,8 +741,8 @@ class TestSaveLoad: | |||
num_replicas = len(device) | |||
driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||
driver2 = generate_driver(10, 10, device=device, fp16=False) | |||
driver1 = generate_driver(20, 1, device=device, fp16=fp16) | |||
driver2 = generate_driver(20, 1, device=device, fp16=False) | |||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | |||
dataloader.batch_sampler.sampler.set_distributed( | |||
@@ -748,11 +754,12 @@ class TestSaveLoad: | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 4) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
# 同步 | |||
dist.barrier() | |||
@@ -797,10 +804,11 @@ class TestSaveLoad: | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
driver2.set_sampler_epoch(replaced_loader, 4) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
res1 = driver1.model( | |||
batch, | |||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||
@@ -14,7 +14,7 @@ else: | |||
@pytest.mark.torch | |||
def test_incorrect_driver(): | |||
model = TorchNormalModel_Classification_1(2, 100) | |||
model = TorchNormalModel_Classification_1(20, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_torch_driver("paddle", 0, model) | |||
@@ -33,7 +33,7 @@ def test_get_single_device(driver, device): | |||
测试正常情况下初始化TorchSingleDriver的情况 | |||
""" | |||
model = TorchNormalModel_Classification_1(2, 100) | |||
model = TorchNormalModel_Classification_1(20, 10) | |||
driver = initialize_torch_driver(driver, device, model) | |||
assert isinstance(driver, TorchSingleDriver) | |||
@@ -52,7 +52,7 @@ def test_get_ddp(driver, device): | |||
测试 ddp 多卡的初始化情况 | |||
""" | |||
model = TorchNormalModel_Classification_1(64, 10) | |||
model = TorchNormalModel_Classification_1(20, 10) | |||
driver = initialize_torch_driver(driver, device, model) | |||
assert isinstance(driver, TorchDDPDriver) | |||
@@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device): | |||
""" | |||
测试传入的device超过范围的情况 | |||
""" | |||
model = TorchNormalModel_Classification_1(2, 100) | |||
model = TorchNormalModel_Classification_1(20, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_torch_driver(driver, device, model) |
@@ -6,7 +6,7 @@ from pkg_resources import parse_version | |||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | |||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
@@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch.utils.data import DataLoader, BatchSampler | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
@@ -67,95 +68,67 @@ class TestTorchDriverFunctions: | |||
model = TorchNormalModel_Classification_1(10, 32) | |||
self.driver = TorchSingleDriver(model, device="cpu") | |||
@pytest.mark.torchpaddle | |||
def test_check_single_optimizer_legality(self): | |||
@pytest.mark.torch | |||
def test_check_optimizers_legality(self): | |||
""" | |||
测试传入单个 optimizer 时的表现 | |||
测试对合法 optimizers 的检查 | |||
""" | |||
# 单个 optimizer | |||
optimizer = torch.optim.Adam( | |||
params=self.driver.model.parameters(), | |||
lr=0.01 | |||
) | |||
self.driver.set_optimizers(optimizer) | |||
optimizer = paddle.optimizer.Adam( | |||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||
learning_rate=0.01, | |||
) | |||
# 传入 torch 的 optimize r时,应该报错 ValueError | |||
with pytest.raises(ValueError): | |||
self.driver.set_optimizers(optimizer) | |||
@pytest.mark.torchpaddle | |||
def test_check_optimizers_legality(self): | |||
""" | |||
测试传入 optimizer list 的表现 | |||
""" | |||
# 列表 | |||
optimizers = [ | |||
torch.optim.Adam( | |||
params=self.driver.model.parameters(), | |||
lr=0.01 | |||
) for i in range(10) | |||
] | |||
self.driver.set_optimizers(optimizers) | |||
optimizers += [ | |||
@pytest.mark.torchpaddle | |||
def test_invalid_optimizers(self): | |||
""" | |||
测试传入非法的 optimizers | |||
""" | |||
optimizer = paddle.optimizer.Adam( | |||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||
learning_rate=0.01, | |||
) | |||
with pytest.raises(TypeError): | |||
self.driver.set_optimizers(optimizer) | |||
optimizers = [ | |||
paddle.optimizer.Adam( | |||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||
learning_rate=0.01, | |||
) | |||
] | |||
with pytest.raises(ValueError): | |||
with pytest.raises(TypeError): | |||
self.driver.set_optimizers(optimizers) | |||
@pytest.mark.torchpaddle | |||
def test_check_dataloader_legality_in_train(self): | |||
@pytest.mark.torch | |||
def test_check_dataloader_legality(self): | |||
""" | |||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||
测试 check_dataloader_legality 函数的表现 | |||
""" | |||
dataloader = DataLoader(TorchNormalDataset()) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建 paddle 的 dataloader | |||
dataloader = paddle.io.DataLoader( | |||
PaddleNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(ValueError): | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
self.driver.check_dataloader_legality(dataloader) | |||
@pytest.mark.torchpaddle | |||
def test_check_dataloader_legality_in_test(self): | |||
def test_check_dataloader_legality_invalid(self): | |||
""" | |||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||
""" | |||
# 此时传入的应该是dict | |||
dataloader = { | |||
"train": DataLoader(TorchNormalDataset()), | |||
"test": DataLoader(TorchNormalDataset()) | |||
} | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 传入的不是 dict,应该报错 | |||
dataloader = DataLoader(TorchNormalDataset()) | |||
with pytest.raises(ValueError): | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建 paddle 的 dataloader | |||
train_loader = paddle.io.DataLoader( | |||
PaddleNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
test_loader = paddle.io.DataLoader( | |||
dataloader = paddle.io.DataLoader( | |||
PaddleNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with pytest.raises(ValueError): | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
with pytest.raises(TypeError): | |||
self.driver.check_dataloader_legality(dataloader) | |||
@pytest.mark.torch | |||
def test_tensor_to_numeric(self): | |||
@@ -515,10 +488,14 @@ class TestSetDistReproDataloader: | |||
# 迭代两个 batch | |||
num_consumed_batches = 2 | |||
already_seen_idx = set() | |||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
replaced_loader.batch_sampler.set_epoch(3) | |||
else: | |||
replaced_loader.batch_sampler.sampler.set_epoch(3) | |||
for idx, batch in enumerate(replaced_loader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_idx.update(batch) | |||
already_seen_idx.update(batch.tolist()) | |||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||
else: | |||
@@ -532,14 +509,16 @@ class TestSetDistReproDataloader: | |||
# 重新改造 dataloader | |||
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.set_epoch(3) | |||
else: | |||
batch_size = replaced_loader.batch_sampler.batch_size | |||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
# 重新构造 dataloader | |||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
new_loader.batch_sampler.sampler.set_epoch(3) | |||
for idx, batch in enumerate(new_loader): | |||
left_idxes.update(batch) | |||
left_idxes.update(batch.tolist()) | |||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||
@@ -550,7 +529,7 @@ class TestSetDistReproDataloader: | |||
# | |||
############################################################################ | |||
def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||
def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||
""" | |||
生成driver | |||
""" | |||
@@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict): | |||
""" | |||
try: | |||
path = "model" | |||
dataset = TorchArgMaxDataset(10, 40) | |||
dataset = TorchNormalXYDataset(20) | |||
dataloader = DataLoader(dataset, batch_size=4) | |||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||
driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) | |||
driver1.save_model(path, only_state_dict) | |||
driver2.load_model(path, only_state_dict) | |||
@@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
try: | |||
path = "model.ckp" | |||
dataset = TorchArgMaxDataset(10, 40) | |||
dataset = TorchNormalXYDataset(20) | |||
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | |||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||
driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") | |||
num_consumed_batches = 2 | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 3) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
sampler_states = dataloader.batch_sampler.state_dict() | |||
save_states = {"num_consumed_batches": num_consumed_batches} | |||
@@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
driver1.set_sampler_epoch(replaced_loader, 3) | |||
for idx, batch in enumerate(replaced_loader): | |||
batch = driver2.move_data_to_device(batch) | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
assert torch.equal(res1["preds"], res2["preds"]) | |||
@@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
@pytest.mark.parametrize("fp16", ([True, False])) | |||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
""" | |||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||
""" | |||
try: | |||
path = "model.ckp" | |||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||
dataset = TorchArgMaxDataset(10, 40) | |||
driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") | |||
dataset = TorchNormalXYDataset(40) | |||
dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | |||
num_consumed_batches = 2 | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
driver1.set_sampler_epoch(dataloader, 3) | |||
for idx, batch in enumerate(dataloader): | |||
if idx >= num_consumed_batches: | |||
break | |||
already_seen_x_set.update(batch["x"]) | |||
already_seen_y_set.update(batch["y"]) | |||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||
save_states = {"num_consumed_batches": num_consumed_batches} | |||
@@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
assert start_batch == 2 * num_consumed_batches | |||
left_x_batches = set() | |||
left_y_batches = set() | |||
# set epoch | |||
driver2.set_sampler_epoch(replaced_loader, 3) | |||
for idx, batch in enumerate(replaced_loader): | |||
batch = driver2.move_data_to_device(batch) | |||
left_x_batches.update(batch["x"]) | |||
left_y_batches.update(batch["y"]) | |||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
res1 = driver1.model.evaluate_step(**batch) | |||
res2 = driver2.model.evaluate_step(**batch) | |||
assert torch.equal(res1["preds"], res2["preds"]) | |||
@@ -0,0 +1,46 @@ | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from jittor.dataset import Dataset | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||
class JittorNormalDataset(Dataset): | |||
def __init__(self, num_of_data=100, **kwargs): | |||
super(JittorNormalDataset, self).__init__(**kwargs) | |||
self._data = list(range(num_of_data)) | |||
self.set_attrs(total_len=num_of_data) | |||
def __getitem__(self, item): | |||
return self._data[item] | |||
class JittorNormalXYDataset(Dataset): | |||
""" | |||
可以被输入到分类模型中的普通数据集 | |||
""" | |||
def __init__(self, num_of_data=1000, **kwargs): | |||
super(JittorNormalXYDataset, self).__init__(**kwargs) | |||
self.num_of_data = num_of_data | |||
self._data = list(range(num_of_data)) | |||
self.set_attrs(total_len=num_of_data) | |||
def __getitem__(self, item): | |||
return { | |||
"x": jt.Var([self._data[item]]), | |||
"y": jt.Var([self._data[item]]) | |||
} | |||
class JittorArgMaxDataset(Dataset): | |||
def __init__(self, num_samples, num_features, **kwargs): | |||
super(JittorArgMaxDataset, self).__init__(**kwargs) | |||
self.x = jt.randn(num_samples, num_features) | |||
self.y = self.x.argmax(dim=-1) | |||
self.set_attrs(total_len=num_samples) | |||
def __getitem__(self, item): | |||
return {"x": self.x[item], "y": self.y[item]} | |||
if __name__ == "__main__": | |||
dataset = JittorNormalDataset() | |||
print(len(dataset)) |
@@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset): | |||
def __getitem__(self, item): | |||
return self._data[item] | |||
class PaddleNormalXYDataset(Dataset): | |||
""" | |||
可以被输入到分类模型中的普通数据集 | |||
""" | |||
def __init__(self, num_of_data=1000): | |||
self.num_of_data = num_of_data | |||
self._data = list(range(num_of_data)) | |||
def __len__(self): | |||
return self.num_of_data | |||
def __getitem__(self, item): | |||
return { | |||
"x": paddle.to_tensor([self._data[item]], dtype="float32"), | |||
"y": paddle.to_tensor([self._data[item]], dtype="float32") | |||
} | |||
class PaddleRandomMaxDataset(Dataset): | |||
class PaddleArgMaxDataset(Dataset): | |||
def __init__(self, num_samples, num_features): | |||
self.x = paddle.randn((num_samples, num_features)) | |||
self.y = self.x.argmax(axis=-1) | |||
@@ -1,4 +1,6 @@ | |||
from functools import reduce | |||
from numpy import dtype | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
@@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset): | |||
def __getitem__(self, item): | |||
return self._data[item] | |||
class TorchNormalXYDataset(Dataset): | |||
""" | |||
可以被输入到分类模型中的普通数据集 | |||
""" | |||
def __init__(self, num_of_data=1000): | |||
self.num_of_data = num_of_data | |||
self._data = list(range(num_of_data)) | |||
def __len__(self): | |||
return self.num_of_data | |||
def __getitem__(self, item): | |||
return { | |||
"x": torch.tensor([self._data[item]], dtype=torch.float), | |||
"y": torch.tensor([self._data[item]], dtype=torch.float) | |||
} | |||
# 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | |||
class TorchNormalDataset_Classification(Dataset): | |||
@@ -0,0 +1,57 @@ | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
from jittor import Module, nn | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
class JittorNormalModel_Classification_1(Module): | |||
""" | |||
基础的 jittor 分类模型 | |||
""" | |||
def __init__(self, num_labels, feature_dimension): | |||
super(JittorNormalModel_Classification_1, self).__init__() | |||
self.num_labels = num_labels | |||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||
self.ac1 = nn.ReLU() | |||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||
self.ac2 = nn.ReLU() | |||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||
self.loss_fn = nn.CrossEntropyLoss() | |||
def execute(self, x): | |||
x = self.ac1(self.linear1(x)) | |||
x = self.ac2(self.linear2(x)) | |||
x = self.output(x) | |||
return x | |||
def train_step(self, x, y): | |||
x = self(x) | |||
return {"loss": self.loss_fn(x, y)} | |||
def evaluate_step(self, x, y): | |||
x = self(x) | |||
return {"pred": x, "target": y.reshape((-1,))} | |||
class JittorNormalModel_Classification_2(Module): | |||
""" | |||
基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景 | |||
""" | |||
def __init__(self, num_labels, feature_dimension): | |||
super(JittorNormalModel_Classification_2, self).__init__() | |||
self.num_labels = num_labels | |||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||
self.ac1 = nn.ReLU() | |||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||
self.ac2 = nn.ReLU() | |||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||
self.loss_fn = nn.CrossEntropyLoss() | |||
def execute(self, x, y): | |||
x = self.ac1(self.linear1(x)) | |||
x = self.ac2(self.linear2(x)) | |||
x = self.output(x) | |||
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} |
@@ -8,7 +8,7 @@ else: | |||
class PaddleNormalModel_Classification_1(Layer): | |||
""" | |||
基础的paddle分类模型 | |||
基础的 paddle 分类模型 | |||
""" | |||
def __init__(self, num_labels, feature_dimension): | |||
super(PaddleNormalModel_Classification_1, self).__init__() | |||
@@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer): | |||
class PaddleNormalModel_Classification_2(Layer): | |||
""" | |||
基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||
基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||
""" | |||
def __init__(self, num_labels, feature_dimension): | |||
super(PaddleNormalModel_Classification_2, self).__init__() | |||
@@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer): | |||
x = self.ac1(self.linear1(x)) | |||
x = self.ac2(self.linear2(x)) | |||
x = self.output(x) | |||
loss = self.loss_fn(x, y) | |||
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} |
@@ -33,7 +33,11 @@ class TestPaddle2Torch: | |||
""" | |||
assert isinstance(tensor, torch.Tensor) | |||
assert tensor.device == torch.device(device) | |||
if device == "cpu": | |||
assert not tensor.is_cuda | |||
else: | |||
assert tensor.is_cuda | |||
assert tensor.device.index == torch.device(device).index | |||
assert tensor.requires_grad == requires_grad | |||
def test_gradient(self): | |||
@@ -261,7 +265,8 @@ class TestJittor2Torch: | |||
if device == "cpu": | |||
assert not tensor.is_cuda | |||
else: | |||
assert tensor.device == torch.device(device) | |||
assert tensor.is_cuda | |||
assert tensor.device.index == torch.device(device).index | |||
assert tensor.requires_grad == requires_grad | |||
def test_var_transfer(self): | |||
@@ -271,7 +276,10 @@ class TestJittor2Torch: | |||
jittor_var = jittor.rand((3, 4, 5)) | |||
res = jittor2torch(jittor_var) | |||
self.check_torch_tensor(res, "cpu", True) | |||
if jittor.flags.use_cuda: | |||
self.check_torch_tensor(res, "cuda:0", True) | |||
else: | |||
self.check_torch_tensor(res, "cpu", True) | |||
res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) | |||
self.check_torch_tensor(res, "cuda:2", True) | |||
@@ -291,7 +299,10 @@ class TestJittor2Torch: | |||
res = jittor2torch(jittor_list) | |||
assert isinstance(res, list) | |||
for t in res: | |||
self.check_torch_tensor(t, "cpu", True) | |||
if jittor.flags.use_cuda: | |||
self.check_torch_tensor(t, "cuda:0", True) | |||
else: | |||
self.check_torch_tensor(t, "cpu", True) | |||
res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) | |||
assert isinstance(res, list) | |||
@@ -327,17 +338,29 @@ class TestJittor2Torch: | |||
} | |||
res = jittor2torch(jittor_dict) | |||
assert isinstance(res, dict) | |||
self.check_torch_tensor(res["tensor"], "cpu", True) | |||
if jittor.flags.use_cuda: | |||
self.check_torch_tensor(res["tensor"], "cuda:0", True) | |||
else: | |||
self.check_torch_tensor(res["tensor"], "cpu", True) | |||
assert isinstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_torch_tensor(t, "cpu", True) | |||
if jittor.flags.use_cuda: | |||
self.check_torch_tensor(t, "cuda:0", True) | |||
else: | |||
self.check_torch_tensor(t, "cpu", True) | |||
assert isinstance(res["int"], int) | |||
assert isinstance(res["string"], str) | |||
assert isinstance(res["dict"], dict) | |||
assert isinstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_torch_tensor(t, "cpu", True) | |||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||
if jittor.flags.use_cuda: | |||
self.check_torch_tensor(t, "cuda:0", True) | |||
else: | |||
self.check_torch_tensor(t, "cpu", True) | |||
if jittor.flags.use_cuda: | |||
self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True) | |||
else: | |||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||
############################################################################ | |||