diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index a93fb55c..96f6747b 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet class _JittorDataset(Dataset): """ 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset - """ def __init__(self, dataset) -> None: @@ -83,7 +82,7 @@ class JittorDataLoader: # TODO 验证支持replacesampler (以后完成) 增加Sampler # 将内部dataset批次设置为1 if isinstance(dataset, Dataset): - dataset.set_attrs(batch_size=1) + dataset.set_attrs(batch_size=1, shuffle=False, endless=False) # FastNLP Datset, collate_fn not None if isinstance(dataset, FDataSet) and collate_fn is None: @@ -115,6 +114,12 @@ class JittorDataLoader: self.cur_batch_indices = None + def __getattr__(self, attr): + if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad", + "keep_numpy_array", "endless", "sampler"]: + return getattr(self.dataset, attr) + raise AttributeError(f"{self} has not attribute '{attr}'") + def __iter__(self): # TODO 第一次迭代后不能设置collate_fn,设置是无效的 if self.cur_batch_indices is None: diff --git a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py index eff8fcfe..b95d965e 100644 --- a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py @@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR: __all__ = [] -def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver: +def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver: r""" 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 @@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") # TODO 实现更详细的判断 - if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: + if device in ["cpu", "gpu", "cuda", None]: return JittorSingleDriver(model, device, **kwargs) elif type(device) is int: return JittorMPIDriver(model, device, **kwargs) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 0dd6d0fb..5b38747d 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -1,23 +1,31 @@ import os -import random from pathlib import Path -from typing import Union, Optional -from functools import partial - -import numpy as np +from typing import Union, Optional, Dict +from contextlib import nullcontext +from dataclasses import dataclass from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader +from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.log import logger from fastNLP.core.utils import apply_to_collection -from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS +from fastNLP.envs import ( + FASTNLP_MODEL_FILENAME, + FASTNLP_CHECKPOINT_FILENAME, +) if _NEED_IMPORT_JITTOR: import jittor as jt from jittor import Module from jittor.optim import Optimizer from jittor.dataset import Dataset + from jittor.dataset import ( + BatchSampler as JittorBatchSampler, + Sampler as JittorSampler, + RandomSampler as JittorRandomSampler, + SequentialSampler as JittorSequentialSampler + ) _reduces = { 'max': jt.max, @@ -56,6 +64,7 @@ class JittorDriver(Driver): else: jt.flags.auto_mixed_precision_level = 0 self.fp16 = fp16 + self._auto_cast = nullcontext # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) @@ -68,7 +77,7 @@ class JittorDriver(Driver): def _check_optimizer_legality(optimizers): for each_optimizer in optimizers: if not isinstance(each_optimizer, Optimizer): - raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " + raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " f"not {type(each_optimizer)}.") def step(self): @@ -117,30 +126,118 @@ class JittorDriver(Driver): model = self.unwrap_model() model.load(filepath) - def save_checkpoint(self): - ... + def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): + dataloader_args = self.get_dataloader_args(dataloader) + if dataloader_args.sampler: + sampler = dataloader_args.sampler + else: + raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") + + num_consumed_batches = states.pop('num_consumed_batches') + if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): + sampler_states = sampler.state_dict() + # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples + # 会造成多余实际消耗的问题。因为 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + if num_consumed_samples_array is not None: + if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 + if dataloader_args.batch_size is not None: + num_consumed_batches = num_consumed_batches * dataloader_args.batch_size + else: # 有可能 batch_size 为 None,就只有损失精度了 + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") + num_consumed_batches = sampler_states['num_consumed_samples'] + sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] + assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + else: + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches + else: + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") + + states['sampler_states'] = sampler_states + else: + raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' + 'state.') + + # 2. 保存模型的状态; + if should_save_model: + if not os.path.exists(folder): + os.mkdir(folder) + model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) + self.save_model(model_path, only_state_dict=only_state_dict) + + # 3. 保存 optimizers 的状态; + states["optimizers_state_dict"] = self.get_optimizer_state() + + # 4. 保存fp16的状态 + + logger.debug("Save optimizer state dict") + jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) def get_optimizer_state(self): - # optimizers_state_dict = {} - # for i in range(len(self.optimizers)): - # optimizer: torch.optim.Optimizer = self.optimizers[i] - # optimizer_state = optimizer.state_dict() - # optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) - # optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; - # return optimizers_state_dict - ... + optimizers_state_dict = {} + for i in range(len(self.optimizers)): + optimizer: Optimizer = self.optimizers[i] + optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的; + return optimizers_state_dict def load_optimizer_state(self, states): - # assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ - # f"checkpoint it is:{len(states)}" - # for i in range(len(self.optimizers)): - # optimizer: torch.optim.Optimizer = self.optimizers[i] - # optimizer.load_state_dict(states[f"optimizer{i}"]) - # logger.debug("Load optimizer state dict.") - ... + assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ + f"checkpoint it is:{len(states)}" + for i in range(len(self.optimizers)): + optimizer: Optimizer = self.optimizers[i] + optimizer.load_state_dict(states[f"optimizer{i}"]) + logger.debug("Load optimizer state dict.") + + def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + + states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) + + # 1. 加载 optimizers 的状态; + optimizers_state_dict = states.pop("optimizers_state_dict") + self.load_optimizer_state(optimizers_state_dict) + + # 2. 加载模型状态; + if should_load_model: + self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) + + # 3. 加载fp16的状态 + + # 4. 恢复 sampler 的状态; + dataloader_args = self.get_dataloader_args(dataloader) + if dataloader_args.sampler is None: + sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=dataloader_args.shuffle) + elif isinstance(dataloader_args.sampler, ReproducibleSampler): + sampler = dataloader_args.sampler + elif isinstance(dataloader_args.sampler, JittorRandomSampler): + sampler = RandomSampler(dataloader_args.sampler.dataset) + logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") + elif isinstance(dataloader_args.sampler, JittorSequentialSampler): + sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False) + logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.") + elif self.is_distributed(): + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" + "`ReproducibleSampler`.") + else: + raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.") + sampler.load_state_dict(states.pop('sampler_states')) + states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) + + # 4. 修改 trainer_state.batch_idx_in_epoch + # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; + if dataloader_args.drop_last: + batch_idx_in_epoch = len( + sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size + else: + batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ + (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size - def load_checkpoint(self): - ... + states["batch_idx_in_epoch"] = batch_idx_in_epoch + + return states def get_evaluate_context(self): return jt.no_grad @@ -198,26 +295,8 @@ class JittorDriver(Driver): """ return batch - @staticmethod - def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover - global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) - process_seed = jt.get_seed() - # back out the base seed so we can use all the bits - base_seed = process_seed - worker_id - ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) - # use 128 bits (4 x 32-bit words) - np.random.seed(ss.generate_state(4)) - # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module - jittor_ss, stdlib_ss = ss.spawn(2) - jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) - # use 128 bits expressed as an integer - stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() - random.seed(stdlib_seed) - def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): - if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: - dataloader.worker_init_fn = partial(self.worker_init_function, - rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) + ... def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; @@ -226,4 +305,45 @@ class JittorDriver(Driver): @staticmethod def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): - pass + @dataclass + class Res: + dataset: Optional[Dataset] = None + batch_sampler: Optional[JittorBatchSampler] = None + sampler: Optional[JittorSampler] = None + batch_size: Optional[int] = None + shuffle: Optional[bool] = None + drop_last: Optional[bool] = None + + res = Res() + from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset + if isinstance(dataloader, JittorDataLoader): + # JittorDataLoader 实际上是迭代 dataset 成员的 + dataloader = dataloader.dataset + if isinstance(dataloader, _JittorDataset): + # 获取最原始的 dataset + res.dataset = dataloader.dataset + else: + res.dataset = dataloader + + # jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取 + res.batch_size = dataloader.batch_size + res.drop_last = dataloader.drop_last + if dataloader.sampler is None: + # sampler 是 None,那么就从 Dataset 的属性中获取 + res.shuffle = dataloader.shuffle + elif isinstance(list(dataloader.sampler.__iter__())[0], (list,tuple)): + # jittor 目前不支持 batch_sampler + raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " + "please check if you have set `Dataset.sampler` as `BatchSampler`") + else: + # sampler 不为 None + res.sampler = dataloader.sampler + if hasattr(dataloader.sampler, "shuffle"): + # 这种情况一般出现在 fastNLP 的 ReproduceSampler 中 + res.shuffle = dataloader.sampler.shuffle + elif isinstance(dataloader.sampler, JittorRandomSampler): + res.shuffle = True + else: + res.shuffle = False + + return res \ No newline at end of file diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index 93187ede..b072b83d 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver): ): super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) + raise NotImplementedError("MPI for Jittor is not supported right now.") self.is_pull_by_jittor_run = is_pull_by_jittor_run self.parallel_device = parallel_device @@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver): return self._data_device return self.parallel_device - def step(self): - # for optimizer in self.optimizers: - # self.grad_scaler.step(optimizer) - # self.grad_scaler.update() - for optimizer in self.optimizers: - optimizer.step() - - def backward(self, loss): - # self.grad_scaler.scale(loss).backward() - for optimizer in self.optimizers: - optimizer.backward(loss) - - def zero_grad(self): - for optimizer in self.optimizers: - optimizer.zero_grad() - def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(fn, batch, signature_fn=signature_fn) diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index b559fd92..4e9b3447 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -1,14 +1,21 @@ from typing import Dict, Union, Tuple, Callable, Optional from .jittor_driver import JittorDriver +from .utils import replace_batch_sampler, replace_sampler from fastNLP.core.utils import auto_param_call from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ + ReproduceBatchSampler +from fastNLP.core.samplers import RandomSampler from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: import jittor as jt + from jittor.dataset import ( + RandomSampler as JittorRandomSampler, + SequentialSampler as JittorSequentialSampler, + ) __all__ = [ "JittorSingleDriver", @@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver): """ return False - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], - reproducible: bool = False, sampler_or_batch_sampler=None): - # reproducible 的相关功能暂时没有实现 + def set_dist_repro_dataloader(self, dataloader, + dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, + reproducible: bool = False): + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; if isinstance(dist, ReproducibleBatchSampler): - raise NotImplementedError - dataloader.batch_sampler = dist_sample - if isinstance(dist, ReproducibleSampler): - raise NotImplementedError - dataloader.batch_sampler.sampler = dist + return replace_batch_sampler(dataloader, dist) + elif isinstance(dist, ReproducibleSampler): + return replace_sampler(dataloader, dist) + + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) if reproducible: - raise NotImplementedError - if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): - return dataloader - elif isinstance(dataloader.batch_sampler, RandomBatchSampler): - return dataloader - else: - # TODO - batch_sampler = RandomBatchSampler( - batch_sampler=dataloader.batch_sampler, - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.drop_last - ) - dataloader.batch_sampler = batch_sampler - return dataloader + if args.sampler is None: + sampler = RandomSampler(args.dataset, args.shuffle) + return replace_sampler(dataloader, sampler) + elif isinstance(args.sampler, JittorRandomSampler): + if getattr(args.sampler, '_num_samples', None) is None \ + and getattr(args.sampler, 'rep', False) is False: + # 如果本来就是随机的,并且没有定制,直接替换掉吧。 + sampler = RandomSampler(args.sampler.dataset, shuffle=True) + logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + elif isinstance(args.sampler, JittorSequentialSampler): + # 需要替换为不要 shuffle 的。 + sampler = RandomSampler(args.sampler.dataset, shuffle=False) + logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + batch_sampler = ReproduceBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py index 50eed7e3..046603d0 100644 --- a/fastNLP/core/drivers/jittor_driver/utils.py +++ b/fastNLP/core/drivers/jittor_driver/utils.py @@ -1,6 +1,29 @@ +import inspect +from copy import deepcopy +from typing import Union + +from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.envs.imports import _NEED_IMPORT_JITTOR if _NEED_IMPORT_JITTOR: - import jittor + from jittor.dataset import Dataset __all__ = [] + +def replace_batch_sampler(dataloader, batch_sampler): + raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " + "please check if you have set `Dataset.sampler` as `BatchSampler`" + "or report this bug to us.") + +def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): + if isinstance(dataloader, JittorDataLoader): + init_params = dict(inspect.signature(dataloader.__init__).parameters) + reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} + reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler) + new_dataloader = type(dataloader)(**reconstruct_args) + new_dataloader.dataset.set_attrs(sampler=sampler) + else: + new_dataloader = deepcopy(dataloader) + new_dataloader.set_attrs(sampler=sampler) + + return new_dataloader \ No newline at end of file diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 090bf567..f809f9ec 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle.io import ( DataLoader, - IterableDataset, Dataset, Sampler, BatchSampler, @@ -97,6 +96,9 @@ class PaddleDriver(Driver): def check_dataloader_legality(self, dataloader): if not isinstance(dataloader, DataLoader): raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") + if dataloader.batch_size is None and dataloader.batch_sampler is None: + raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" + "is not None") @staticmethod def _check_optimizer_legality(optimizers): @@ -107,7 +109,7 @@ class PaddleDriver(Driver): """ for each_optimizer in optimizers: if not isinstance(each_optimizer, Optimizer): - raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " + raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " f"not {type(each_optimizer)}.") @staticmethod @@ -263,9 +265,7 @@ class PaddleDriver(Driver): optimizers_state_dict = {} for i in range(len(self.optimizers)): optimizer: Optimizer = self.optimizers[i] - optimizer_state = optimizer.state_dict() - optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") - optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; + optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu") return optimizers_state_dict @@ -399,6 +399,8 @@ class PaddleDriver(Driver): def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): dataloader.batch_sampler.set_epoch(cur_epoch_idx) + elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)): + dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx) @staticmethod def get_dataloader_args(dataloader: "DataLoader"): diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 17d65d54..9449782b 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -99,7 +99,7 @@ class TorchDriver(Driver): def _check_optimizer_legality(optimizers): for each_optimizer in optimizers: if not isinstance(each_optimizer, Optimizer): - raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " + raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " f"not {type(each_optimizer)}.") @staticmethod diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 679bb1cf..874ad895 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): self.num_consumed_samples = 0 self.during_iter = True - indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) + indices = list(range(self.num_samples)) if self.shuffle: if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 @@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): if len(indices)%self.batch_size!=0: batches.append(indices[_num_batches*self.batch_size:]) - need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas + need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: if len(batches) > 0: if len(batches[-1])int: """ @@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" " consumed. ") states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, - 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, + 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, 'batch_size': self.batch_size, 'num_replicas': self.num_replicas} @@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ + assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ "and current dataset." self.seed = states['seed'] self.epoch = states['epoch'] @@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): :return: """ num_consumed_samples = self.num_consumed_samples - return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ - self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) + return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ + self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) + + @property + def num_samples(self): + return getattr(self.dataset, 'total_len', len(self.dataset)) def __len__(self)->int: """ @@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if len(sorted_indices)%self.batch_size!=0: batches.append(sorted_indices[_num_batches*self.batch_size:]) - need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas + need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: if len(batches) > 0: if len(batches[-1]) Dict: states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, - 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} + 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle} return states def load_state_dict(self, states: Dict): @@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ - f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." + assert length == self.num_samples, "The number of samples is different between the checkpoint " \ + f"record({length}) and current dataset({self.num_samples})." self.seed = states['seed'] self.epoch = states['epoch'] self.num_consumed_samples = states['num_consumed_samples'] @@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler): :return: """ num_consumed_samples = self.num_consumed_samples - return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ - self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) + return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ + self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) + + @property + def num_samples(self): + """ + 返回样本的总数 + :return: + """ + return getattr(self.dataset, 'total_len', len(self.dataset)) class SequentialSampler(RandomSampler): """ @@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler): :return: """ - return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) + return list(range(self.num_samples)) def state_dict(self) -> Dict: - states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, - 'length': getattr(self.dataset, 'total_len', len(self.dataset)) - } + states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples} return states def load_state_dict(self, states: Dict): @@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ - f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." + assert length == self.num_samples, "The number of samples is different between the checkpoint " \ + f"record({length}) and current dataset({self.num_samples})." self.num_consumed_samples = states['num_consumed_samples'] if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 self.num_consumed_samples = 0 @@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler): except BaseException as e: logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") - assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ - f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." - assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." + assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \ + f"`length`({self.num_samples}) should be equal." + assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length." self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index 0ff55674..69eb532d 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; :return: """ - num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas - num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) + num_common = self.num_samples//self.num_replicas + num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas)) return num_samples def __iter__(self): @@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): :return: """ if self.shuffle: - indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) + indices = list(range(self.num_samples)) seed = self.seed + self.epoch rng = np.random.default_rng(abs(seed)) rng.shuffle(indices) if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 self.epoch -= 1 else: - indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) + indices = list(range(self.num_samples)) return indices def set_epoch(self, epoch: int) -> None: @@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): :param rank: :return: """ - assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ - f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." + assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \ + f"number of samples({self.num_samples})." assert num_replicas>0 and isinstance(num_replicas, int) assert isinstance(rank, int) and 0<=rank List[int]: - return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) + return list(range(self.num_samples)) diff --git a/tests/core/controllers/_test_trainer_fleet.py b/tests/core/controllers/_test_trainer_fleet.py index 89c8762e..9221cec6 100644 --- a/tests/core/controllers/_test_trainer_fleet.py +++ b/tests/core/controllers/_test_trainer_fleet.py @@ -27,7 +27,7 @@ from paddle.optimizer import Adam from paddle.io import DataLoader from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback @dataclass @@ -52,12 +52,12 @@ def test_trainer_fleet( optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) train_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), + dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), batch_size=MNISTTrainFleetConfig.batch_size, shuffle=True ) val_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), + dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), batch_size=MNISTTrainFleetConfig.batch_size, shuffle=True ) diff --git a/tests/core/controllers/_test_trainer_fleet_outside.py b/tests/core/controllers/_test_trainer_fleet_outside.py index b021d1f6..aa62abac 100644 --- a/tests/core/controllers/_test_trainer_fleet_outside.py +++ b/tests/core/controllers/_test_trainer_fleet_outside.py @@ -24,7 +24,7 @@ from paddle.io import DataLoader import paddle.distributed.fleet as fleet from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 -from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback @dataclass @@ -54,12 +54,12 @@ def test_trainer_fleet( optimizers = fleet.distributed_optimizer(optimizers) train_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), + dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), batch_size=MNISTTrainFleetConfig.batch_size, shuffle=True ) val_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), + dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), batch_size=MNISTTrainFleetConfig.batch_size, shuffle=True ) diff --git a/tests/core/controllers/_test_trainer_jittor.py b/tests/core/controllers/_test_trainer_jittor.py index 13ab2e8b..a2369b28 100644 --- a/tests/core/controllers/_test_trainer_jittor.py +++ b/tests/core/controllers/_test_trainer_jittor.py @@ -46,8 +46,8 @@ class LSTM(Module): def init_hidden(self, x): # batch_first batch_size = x.shape[0] - h0 = jt.randn(1, batch_size, hidden_size) - c0 = jt.randn(1, batch_size, hidden_size) + h0 = jt.randn(1, batch_size, self.hidden_size) + c0 = jt.randn(1, batch_size, self.hidden_size) return h0, c0 diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index b6cefdf3..94d85f22 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -1,4 +1,5 @@ import pytest +from fastNLP.core.callbacks import callback from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.controllers.trainer import Evaluator @@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR: else: from fastNLP.core.utils.dummy_class import DummyClass as Module from fastNLP.core.utils.dummy_class import DummyClass as Dataset +jt.flags.use_cuda=1 class JittorNormalModel_Classification(Module): @@ -68,11 +70,9 @@ class TrainJittorConfig: batch_size: int = 4 shuffle: bool = True - @pytest.mark.parametrize("driver", ["jittor"]) -@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"]) +@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) -@pytest.mark.jittor def test_trainer_jittor( driver, device, diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index b38f0161..834c9363 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE: from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset from tests.helpers.utils import magic_argv_env_context @dataclass @@ -44,12 +44,12 @@ def test_trainer_paddle( ) optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) train_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension), + dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension), batch_size=TrainPaddleConfig.batch_size, shuffle=True ) val_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension), + dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension), batch_size=TrainPaddleConfig.batch_size, shuffle=True ) diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 717f3308..1a90aa11 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -76,7 +76,7 @@ class TestPaddle: from paddle.io import Dataset import paddle - class PaddleRandomMaxDataset(Dataset): + class PaddleArgMaxDataset(Dataset): def __init__(self, num_samples, num_features): self.x = paddle.randn((num_samples, num_features)) self.y = self.x.argmax(axis=-1) @@ -87,7 +87,7 @@ class TestPaddle: def __getitem__(self, item): return {"x": self.x[item], "y": self.y[item]} - ds = PaddleRandomMaxDataset(100, 2) + ds = PaddleArgMaxDataset(100, 2) dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) for batch in dl: print(batch) \ No newline at end of file diff --git a/tests/core/drivers/jittor_driver/test_initialize_jittor_driver.py b/tests/core/drivers/jittor_driver/test_initialize_jittor_driver.py new file mode 100644 index 00000000..f03147d0 --- /dev/null +++ b/tests/core/drivers/jittor_driver/test_initialize_jittor_driver.py @@ -0,0 +1,45 @@ +import pytest + +from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver +from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver +from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR +if _NEED_IMPORT_JITTOR: + import jittor as jt + +@pytest.mark.jittor +def test_incorrect_driver(): + + model = JittorNormalModel_Classification_1(20, 10) + with pytest.raises(ValueError): + driver = initialize_jittor_driver("torch", 0, model) + +@pytest.mark.jittor +@pytest.mark.parametrize( + "device", + ["cpu", "gpu", None, "cuda"] +) +def test_get_single_device(device): + """ + 测试正常情况下初始化 JittorSingleDriver 的情况 + """ + + model = JittorNormalModel_Classification_1(20, 10) + driver = initialize_jittor_driver("jittor", device, model) + assert isinstance(driver, JittorSingleDriver) + +@pytest.mark.jittor +@pytest.mark.parametrize( + "device", + [[0, 2, 3], 1, 2] +) +def test_get_mpi(device): + """ + 测试 jittor 多卡的初始化情况 + """ + + model = JittorNormalModel_Classification_1(20, 10) + with pytest.raises(NotImplementedError): + driver = initialize_jittor_driver("jittor", device, model) + + # assert isinstance(driver, JittorMPIDriver) diff --git a/tests/core/drivers/jittor_driver/test_jittor_driver.py b/tests/core/drivers/jittor_driver/test_jittor_driver.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/core/drivers/jittor_driver/test_single_device.py b/tests/core/drivers/jittor_driver/test_single_device.py index 2e220974..7597daf7 100644 --- a/tests/core/drivers/jittor_driver/test_single_device.py +++ b/tests/core/drivers/jittor_driver/test_single_device.py @@ -1,99 +1,614 @@ import pytest -import os +from copy import deepcopy +from pathlib import Path -import numpy as np - -from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver -from fastNLP.envs.imports import _NEED_IMPORT_JITTOR +from fastNLP.core.drivers.jittor_driver import JittorSingleDriver +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler +from fastNLP.core.dataloaders import JittorDataLoader +from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 +from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset +from tests.helpers.datasets.torch_data import TorchNormalDataset +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from fastNLP.envs.distributed import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH if _NEED_IMPORT_JITTOR: - import jittor as jt # 将 jittor 引入 - from jittor import nn, Module # 引入相关的模块 - from jittor import init - from jittor.dataset import MNIST -else: - from fastNLP.core.utils.dummy_class import DummyClass as Module + import jittor as jt + from jittor.dataset import ( + BatchSampler as JittorBatchSampler, + RandomSampler as JittorRandomSampler, + SequentialSampler as JittorSequentialSampler, + SubsetRandomSampler as JittorSubsetRandomSampler + ) +if _NEED_IMPORT_TORCH: + import torch +def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False): + """ + :param dataset: + :param use_dataloader: 是否使用 JittorDataLoader 包裹 + :param sampler: 使用 BatchSampler Samlper 还是不使用 Sampler + """ + if use_dataloader: + dataloader = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + dataloader.dataset.set_attrs(sampler=sampler) + else: + dataloader = dataset + dataloader.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler) -class Model(Module): - def __init__ (self): - super (Model, self).__init__() - self.conv1 = nn.Conv (3, 32, 3, 1) # no padding - - self.conv2 = nn.Conv (32, 64, 3, 1) - self.bn = nn.BatchNorm(64) - - self.max_pool = nn.Pool (2, 2) - self.relu = nn.Relu() - self.fc1 = nn.Linear (64 * 12 * 12, 256) - self.fc2 = nn.Linear (256, 10) - - def execute(self, x) : - # it's simliar to forward function in Pytorch - x = self.conv1 (x) - x = self.relu (x) - - x = self.conv2 (x) - x = self.bn (x) - x = self.relu (x) - - x = self.max_pool (x) - x = jt.reshape (x, [x.shape[0], -1]) - x = self.fc1 (x) - x = self.relu(x) - x = self.fc2 (x) - return x + return dataloader +############################################################################ +# +# 测试基类 JittorDrvier 中的一些简单函数 +# +############################################################################ + +class TestJittorDriverFunctions: + """ + 使用 JittorSingleDriver 测试基类的函数 + """ + + @classmethod + def setup_class(self): + model = JittorNormalModel_Classification_1(10, 32) + self.driver = JittorSingleDriver(model, device="cpu") + + @pytest.mark.jittor + def test_check_optimizers_legality(self): + """ + 测试对合法的 optimizers 的检查 + """ + # 单个 optimizer + optimizer = jt.optim.Adam( + params=self.driver.model.parameters(), + lr=0.01 + ) + self.driver.set_optimizers(optimizer) + + # optimizer 列表 + optimizers = [ + jt.optim.Adam( + params=self.driver.model.parameters(), + lr=0.01 + ) for i in range(10) + ] + self.driver.set_optimizers(optimizers) + + @pytest.mark.torchjittor + def test_invalid_optimizers(self): + """ + 测试传入非法的 optimizers + """ + # 单个 optimizer + optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) + with pytest.raises(TypeError): + self.driver.set_optimizers(optimizer) + + optimizers = [ + torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) + ] + + with pytest.raises(TypeError): + self.driver.set_optimizers(optimizers) + + @pytest.mark.jittor + def test_check_dataloader_legality(self): + """ + 测试 check_dataloader_legality 函数的表现 + """ + # 使用 JittorDataLoader + dataloader = JittorDataLoader(JittorNormalDataset()) + self.driver.check_dataloader_legality(dataloader) + # 使用 jittor.dataset.Dataset + self.driver.check_dataloader_legality(JittorNormalDataset()) + + @pytest.mark.torchjittor + def test_check_dataloader_legality_invalid(self): + """ + 测试 check_dataloader_legality 函数传入其他类型的表现 + """ + # 创建 torch 的 dataloader + dataloader = torch.utils.data.DataLoader( + TorchNormalDataset(), + batch_size=32, shuffle=True + ) + with pytest.raises(TypeError): + self.driver.check_dataloader_legality(dataloader) + + @pytest.mark.jittor + def test_tensor_to_numeric(self): + """ + 测试 tensor_to_numeric 函数 + """ + # 单个张量 + tensor = jt.Var(3) + res = JittorSingleDriver.tensor_to_numeric(tensor) + assert res == 3 + + tensor = jt.rand(3, 4) + res = JittorSingleDriver.tensor_to_numeric(tensor) + assert res == tensor.tolist() + + # 张量list + tensor_list = [jt.rand(6, 4, 2) for i in range(10)] + res = JittorSingleDriver.tensor_to_numeric(tensor_list) + assert isinstance(res, list) + tensor_list = [t.tolist() for t in tensor_list] + assert res == tensor_list + + # 张量tuple + tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)]) + res = JittorSingleDriver.tensor_to_numeric(tensor_tuple) + assert isinstance(res, tuple) + tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) + assert res == tensor_tuple + + # 张量dict + tensor_dict = { + "tensor": jt.rand(3, 4), + "list": [jt.rand(6, 4, 2) for i in range(10)], + "dict":{ + "list": [jt.rand(6, 4, 2) for i in range(10)], + "tensor": jt.rand(3, 4) + }, + "int": 2, + "string": "test string" + } + + res = JittorSingleDriver.tensor_to_numeric(tensor_dict) + assert isinstance(res, dict) + assert res["tensor"] == tensor_dict["tensor"].tolist() + assert isinstance(res["list"], list) + for r, d in zip(res["list"], tensor_dict["list"]): + assert r == d.tolist() + assert isinstance(res["int"], int) + assert isinstance(res["string"], str) + assert isinstance(res["dict"], dict) + assert isinstance(res["dict"]["list"], list) + for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): + assert r == d.tolist() + assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() + + @pytest.mark.jittor + def test_tensor_to_numeric_reduce(self): + tensor = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + res_max = JittorSingleDriver.tensor_to_numeric(tensor, reduce="max") + res_min = JittorSingleDriver.tensor_to_numeric(tensor, reduce="min") + res_sum = JittorSingleDriver.tensor_to_numeric(tensor, reduce="sum") + res_mean = JittorSingleDriver.tensor_to_numeric(tensor, reduce="mean") + + assert res_max == 6 + assert res_min == 1 + assert res_sum == 21 + assert res_mean == 3.5 + + + @pytest.mark.jittor + def test_set_model_mode(self): + """ + 测试 set_model_mode 函数 + """ + self.driver.set_model_mode("train") + assert self.driver.model.is_training() + self.driver.set_model_mode("eval") + assert not self.driver.model.is_training() + # 应该报错 + with pytest.raises(AssertionError): + self.driver.set_model_mode("test") + + @pytest.mark.jittor + def test_move_model_to_device_cpu(self): + """ + 测试 move_model_to_device 函数,仅测试能否运行 + """ + JittorSingleDriver.move_model_to_device(self.driver.model, "cpu") + + @pytest.mark.jittor + def test_move_model_to_device_gpu(self): + """ + 测试 move_model_to_device 函数,仅测试能否运行 + """ + JittorSingleDriver.move_model_to_device(self.driver.model, "gpu") + + @pytest.mark.jittor + def test_set_deterministic_dataloader(self): + """ + 测试 set_deterministic_dataloader,仅测试能否运行 + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = JittorDataLoader(JittorNormalDataset()) + self.driver.set_deterministic_dataloader(dataloader) + self.driver.set_deterministic_dataloader(JittorNormalDataset()) + + @pytest.mark.jittor + def test_set_sampler_epoch(self): + """ + 测试 set_sampler_epoch + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = JittorDataLoader(JittorNormalDataset()) + self.driver.set_sampler_epoch(dataloader, 0) + self.driver.set_sampler_epoch(JittorNormalDataset(), 0) + + @pytest.mark.jittor + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader): + """ + 测试正常情况下 get_dataloader_args 的表现 + """ + dataloader = get_dataloader( + JittorNormalDataset(), + use_dataloader=use_dataloader, + sampler=None, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last + ) + res = JittorSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, JittorNormalDataset) + assert res.sampler is None + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.jittor + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader): + """ + 测试替换了 sampler 后 get_dataloader_args 的表现 + """ + dataset = JittorNormalDataset() + dataloader = get_dataloader( + dataset, + use_dataloader=use_dataloader, + batch_size=batch_size, + sampler=RandomSampler(dataset, shuffle=shuffle), + shuffle=shuffle, + drop_last=drop_last + ) + + res = JittorSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, JittorNormalDataset) + assert isinstance(res.sampler, RandomSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + +############################################################################ +# +# 测试 JittorSingleDrvier 中的一些简单函数 +# +############################################################################ @pytest.mark.jittor -@pytest.mark.skip("Skip jittor tests now.") -class TestSingleDevice: - - def test_on_gpu_without_fp16(self): - # TODO get_dataloader - batch_size = 64 - learning_rate = 0.1 - epochs = 5 - losses = [] - losses_idx = [] - - train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True) - val_loader = MNIST(train=False, batch_size=1, shuffle=False) - - model = Model() - driver = JittorSingleDriver(model, device=[1]) - optimizer = nn.SGD(model.parameters(), learning_rate) - driver.set_optimizers(optimizer) - - for epoch in range(epochs): - driver.set_model_mode("train") - lens = len(train_loader) - for batch_idx, (inputs, targets) in enumerate(train_loader): - outputs =driver.train_step(inputs) - loss = nn.cross_entropy_loss(outputs, targets) - driver.backward(loss) - driver.step() - driver.zero_grad() - losses.append(loss.data[0]) - losses_idx.append(epoch * lens + batch_idx) - - test_loss = 0 - correct = 0 - total_acc = 0 - total_num = 0 - driver.set_model_mode("eval") - for batch_idx, (inputs, targets) in enumerate(val_loader): - batch_size = inputs.shape[0] - outputs = driver.test_step(inputs) - pred = np.argmax(outputs.data, axis=1) - acc = np.sum(targets.data==pred) - total_acc += acc - total_num += batch_size - acc = acc / batch_size - assert total_acc / total_num > 0.95 - - - def test_on_cpu_without_fp16(self): - pass - - def test_on_gpu_with_fp16(self): - pass \ No newline at end of file +class TestSingleDeviceFunction: + """ + 测试其它函数的测试例 + """ + + @classmethod + def setup_class(cls): + model = JittorNormalModel_Classification_1(10, 784) + cls.driver = JittorSingleDriver(model, device="cpu") + + def test_unwrap_model(self): + """ + 测试能否运行 + """ + res = self.driver.unwrap_model() + assert res is self.driver.model + + def test_is_distributed(self): + assert self.driver.is_distributed() == False + + def test_move_data_to_device(self): + self.driver.move_data_to_device(jt.rand(32, 64)) + + +############################################################################ +# +# 测试 set_dist_repro_dataloader 函数 +# +############################################################################ + +@pytest.mark.jittor +class TestSetDistReproDataloader: + """ + 专门测试 set_dist_repro_dataloader 函数的类 + """ + def setup_method(self): + self.dataset = JittorNormalDataset(20) + model = JittorNormalModel_Classification_1(10, 32) + self.driver = JittorSingleDriver(model, device="cpu") + + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_with_reproducible_false(self, use_dataloader): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 + 当dist为字符串时,此时应该返回原来的 dataloader + """ + dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert replaced_loader is dataloader + + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("sampler", [None, "random", "sequential"]) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_with_reproducible_true(self, shuffle, sampler, use_dataloader): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 + 当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler + """ + if sampler == "random": + sampler = JittorRandomSampler(self.dataset) + _shuffle = True + elif sampler == "sequential": + sampler = JittorSequentialSampler(self.dataset) + _shuffle = False + else: + _shuffle = shuffle + dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.sampler, RandomSampler) + assert replaced_loader.sampler.shuffle == _shuffle + assert replaced_loader.batch_size == dataloader.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) + + @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_with_dist_batch_sampler(self, shuffle, use_dataloader): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler + 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler + jittor 暂时不支持这种情况,会报错 + """ + dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) + dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False) + + with pytest.raises(RuntimeError): + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) + + @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_with_dist_sampler(self, shuffle, use_dataloader): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 + 应该返回新的 dataloader,并将 sampler 替换为 dist 对应的 Sampler + """ + dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) + dist = RandomSampler(self.dataset, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.sampler, RandomSampler) + assert replaced_loader.sampler is dist + assert replaced_loader.batch_size == dataloader.batch_size + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) + + @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = get_dataloader( + self.dataset, + use_dataloader=use_dataloader, + sampler=ReproduceBatchSampler( + JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), + batch_size=4, + drop_last=False, + ), + batch_size=4, + shuffle=shuffle, + ) + with pytest.raises(RuntimeError): + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("use_dataloader", [True, False]) + def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = get_dataloader( + self.dataset, + use_dataloader=use_dataloader, + sampler=RandomSampler(self.dataset, shuffle), + batch_size=2, + shuffle=shuffle, + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.sampler is dataloader.sampler) + assert isinstance(replaced_loader.sampler, RandomSampler) + assert replaced_loader.batch_size == 2 + assert replaced_loader.shuffle == shuffle + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) + + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader): + """ + 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_consumed_batches = 2 + already_seen_idx = set() + replaced_loader.sampler.set_epoch(6) + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch.tolist()) + sampler_states = replaced_loader.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 JittorNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + batch_size = replaced_loader.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新构造 dataloader + if use_dataloader: + dataset = deepcopy(replaced_loader.dataset.dataset) + else: + dataset = deepcopy(replaced_loader) + new_loader = get_dataloader( + dataset=dataset, + use_dataloader=use_dataloader, + sampler = RandomSampler(dataset, shuffle=shuffle), + batch_size=batch_size, + shuffle=shuffle, + drop_last=False + ) + new_loader.sampler.load_state_dict(sampler_states) + new_loader.sampler.set_epoch(6) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch.tolist()) + + print(already_seen_idx) + print(left_idxes) + + assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len + assert len(left_idxes | already_seen_idx) == self.dataset.total_len + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ + +def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01): + """ + 生成driver + """ + model = JittorNormalModel_Classification_1(labels, features) + opt = jt.optim.Adam(params=model.parameters(), lr=lr) + driver = JittorSingleDriver(model, device=device, fp16=fp16) + driver.set_optimizers(opt) + driver.setup() + + return driver + +@pytest.mark.jittor +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("use_dataloader", [True, False]) +def test_save_and_load_model(only_state_dict, use_dataloader): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + dataset = JittorNormalXYDataset(20) + dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True) + driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") + + driver1.save_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) + + for batch in dataloader: + batch = driver1.move_data_to_device(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + + assert jt.all_(jt.equal(res1["pred"], res2["pred"])) + finally: + rank_zero_rm(path) + +@pytest.mark.jittor +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("use_dataloader", [True, False]) +def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \ + generate_random_driver(20, 1, device="gpu", lr=0.001) + dataset = JittorNormalXYDataset(20) + dataloader = get_dataloader( + dataset, use_dataloader, + sampler = RandomSampler(dataset, True), + batch_size=4, + shuffle=True + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 7) + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) + + sampler_states = dataloader.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + + # 加载 + # 更改 batch_size + dataloader = get_dataloader( + dataset, use_dataloader, + sampler=RandomSampler(dataset, True), + batch_size=2, + shuffle=True + ) + load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + assert driver2.optimizers[0].lr == driver1.optimizers[0].lr + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.sampler, RandomSampler) + assert replaced_loader.sampler.seed == sampler_states["seed"] + assert replaced_loader.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches + assert replaced_loader.sampler.dataset.total_len == sampler_states["length"] + assert replaced_loader.sampler.shuffle == sampler_states["shuffle"] + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 7) + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert jt.all_(jt.equal(res1["pred"], res2["pred"])) + + assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len + assert len(left_x_batches | already_seen_x_set) == dataset.total_len + assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len + assert len(left_y_batches | already_seen_y_set) == dataset.total_len + finally: + rank_zero_rm(path) diff --git a/tests/core/drivers/jittor_driver/test_utils.py b/tests/core/drivers/jittor_driver/test_utils.py index e69de29b..6eb3d3a3 100644 --- a/tests/core/drivers/jittor_driver/test_utils.py +++ b/tests/core/drivers/jittor_driver/test_utils.py @@ -0,0 +1,43 @@ +import pytest + +from fastNLP.core.drivers.jittor_driver.utils import replace_sampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler +from fastNLP.core.dataloaders import JittorDataLoader +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR +if _NEED_IMPORT_JITTOR: + import jittor as jt + +from tests.helpers.datasets.jittor_data import JittorNormalDataset + +@pytest.mark.jittor +@pytest.mark.parametrize("dataset", [ + JittorNormalDataset(20, batch_size=10, shuffle=True), + JittorNormalDataset(20, batch_size=5, drop_last=True), + JittorNormalDataset(20) +]) +def test_replace_sampler_dataset(dataset): + dataset = JittorNormalDataset(20) + sampler = RandomSampler(dataset) + + replaced_loader = replace_sampler(dataset, sampler) + + assert not (replaced_loader is dataset) + assert isinstance(replaced_loader.sampler, RandomSampler) + assert replaced_loader.batch_size == dataset.batch_size + assert replaced_loader.drop_last == dataset.drop_last + assert replaced_loader.shuffle == dataset.shuffle + assert replaced_loader.total_len == dataset.total_len + +@pytest.mark.jittor +def test_replace_sampler_jittordataloader(): + dataset = JittorNormalDataset(20, batch_size=10, shuffle=True) + dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True) + sampler = RandomSampler(dataset) + + replaced_loader = replace_sampler(dataloader, sampler) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset) + assert isinstance(replaced_loader.sampler, RandomSampler) + assert replaced_loader.batch_size == 8 + assert replaced_loader.shuffle == True \ No newline at end of file diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 5f90ed12..93d3e832 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( UnrepeatedSequentialSampler, ) from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE: import paddle.distributed as dist from paddle.io import DataLoader, BatchSampler -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): - paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) +def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"): + paddle_model = PaddleNormalModel_Classification_1(labels, features) paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) driver = PaddleFleetDriver( model=paddle_model, @@ -465,10 +465,14 @@ class TestSetDistReproDataloader: num_replicas = len(self.device) num_consumed_batches = 2 already_seen_idx = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.set_epoch(10) + else: + sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10) for idx, batch in enumerate(replaced_loader): if idx >= num_consumed_batches: break - already_seen_idx.update(batch) + already_seen_idx.update(batch.tolist()) dist.barrier() if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() @@ -496,6 +500,7 @@ class TestSetDistReproDataloader: pad=True ) new_loader.batch_sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.set_epoch(10) else: batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas @@ -508,8 +513,9 @@ class TestSetDistReproDataloader: ) new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.sampler.set_epoch(10) for idx, batch in enumerate(new_loader): - left_idxes.update(batch) + left_idxes.update(batch.tolist()) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas @@ -533,7 +539,7 @@ class TestSaveLoad: cls.driver = generate_driver(10, 10, device=[0,1]) def setup_method(self): - self.dataset = PaddleRandomMaxDataset(20, 10) + self.dataset = PaddleNormalXYDataset(40) @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @@ -545,12 +551,12 @@ class TestSaveLoad: path = "model" dataloader = DataLoader(self.dataset, batch_size=2) - self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1) if only_state_dict: self.driver1.save_model(path, only_state_dict) else: - self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) + self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))]) # 同步 dist.barrier() @@ -594,8 +600,8 @@ class TestSaveLoad: path = "model.ckp" num_replicas = len(device) - self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ - generate_driver(10, 10, device=device, fp16=False) + self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \ + generate_driver(40, 1, device=device, fp16=False) dataloader = DataLoader( dataset=self.dataset, batch_sampler=BucketedBatchSampler( @@ -613,11 +619,12 @@ class TestSaveLoad: already_seen_x_set = set() already_seen_y_set = set() + self.driver1.set_sampler_epoch(dataloader, 2) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) + already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) # 同步 dist.barrier() @@ -669,10 +676,11 @@ class TestSaveLoad: assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + self.driver2.set_sampler_epoch(replaced_loader, 2) for idx, batch in enumerate(replaced_loader): - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape((-1, )).tolist()) + left_y_batches.update(batch["y"].reshape((-1, )).tolist()) res1 = self.driver1.model( batch, fastnlp_fn=self.driver1.model._layers.model.evaluate_step, @@ -709,8 +717,8 @@ class TestSaveLoad: num_replicas = len(device) - self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) - self.driver2 = generate_driver(10, 10, device=device, fp16=False) + self.driver1 = generate_driver(40, 1, device=device, fp16=fp16) + self.driver2 = generate_driver(40, 1, device=device, fp16=False) batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = RandomSampler(self.dataset, True) batch_sampler.sampler.set_distributed( @@ -726,11 +734,12 @@ class TestSaveLoad: already_seen_x_set = set() already_seen_y_set = set() + self.driver1.set_sampler_epoch(dataloader, 2) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) + already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) # 同步 dist.barrier() @@ -779,10 +788,11 @@ class TestSaveLoad: assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + self.driver2.set_sampler_epoch(replaced_loader, 2) for idx, batch in enumerate(replaced_loader): - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape((-1, )).tolist()) + left_y_batches.update(batch["y"].reshape((-1, )).tolist()) res1 = self.driver1.model( batch, fastnlp_fn=self.driver1.model._layers.model.evaluate_step, diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index ad99d4a8..7e567c84 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE: @pytest.mark.paddle def test_incorrect_driver(): - model = PaddleNormalModel_Classification_1(2, 100) + model = PaddleNormalModel_Classification_1(20, 10) with pytest.raises(ValueError): driver = initialize_paddle_driver("torch", 0, model) @@ -26,7 +26,7 @@ def test_get_single_device(device): 测试正常情况下初始化 PaddleSingleDriver 的情况 """ - model = PaddleNormalModel_Classification_1(2, 100) + model = PaddleNormalModel_Classification_1(20, 10) driver = initialize_paddle_driver("paddle", device, model) assert isinstance(driver, PaddleSingleDriver) @@ -41,7 +41,7 @@ def test_get_fleet(device): 测试 fleet 多卡的初始化情况 """ - model = PaddleNormalModel_Classification_1(64, 10) + model = PaddleNormalModel_Classification_1(20, 10) driver = initialize_paddle_driver("paddle", device, model) assert isinstance(driver, PaddleFleetDriver) @@ -56,6 +56,6 @@ def test_device_out_of_range(device): """ 测试传入的device超过范围的情况 """ - model = PaddleNormalModel_Classification_1(2, 100) + model = PaddleNormalModel_Classification_1(20, 10) with pytest.raises(ValueError): driver = initialize_paddle_driver("paddle", device, model) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 3c2d7e27..e2d63a4b 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -4,14 +4,16 @@ from pathlib import Path from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH + if _NEED_IMPORT_PADDLE: import paddle from paddle.io import DataLoader, BatchSampler + if _NEED_IMPORT_TORCH: import torch @@ -31,102 +33,70 @@ class TestPaddleDriverFunctions: model = PaddleNormalModel_Classification_1(10, 32) self.driver = PaddleSingleDriver(model, device="cpu") - @pytest.mark.torchpaddle - def test_check_single_optimizer_legality(self): + @pytest.mark.paddle + def test_check_optimizers_legality(self): """ - 测试传入单个 optimizer 时的表现 + 测试对合法的 optimizers 的检查 """ + # 单个 optimizer optimizer = paddle.optimizer.Adam( parameters=self.driver.model.parameters(), learning_rate=0.01 ) - self.driver.set_optimizers(optimizer) - optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) - # 传入torch的optimizer时,应该报错ValueError - with pytest.raises(ValueError): - self.driver.set_optimizers(optimizer) - - @pytest.mark.torchpaddle - def test_check_optimizers_legality(self): - """ - 测试传入 optimizer list 的表现 - """ + # optimizer 列表 optimizers = [ paddle.optimizer.Adam( parameters=self.driver.model.parameters(), learning_rate=0.01 ) for i in range(10) ] - self.driver.set_optimizers(optimizers) - optimizers += [ + @pytest.mark.torchpaddle + def test_invalid_optimizers(self): + """ + 测试传入非法的 optimizers + """ + # 单个 optimizer + optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) + with pytest.raises(TypeError): + self.driver.set_optimizers(optimizer) + + optimizers = [ torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) ] - with pytest.raises(ValueError): + with pytest.raises(TypeError): self.driver.set_optimizers(optimizers) - @pytest.mark.torchpaddle - def test_check_dataloader_legality_in_train(self): + @pytest.mark.paddle + def test_check_dataloader_legality(self): """ - 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 + 测试 check_dataloader_legality 函数的表现 """ dataloader = DataLoader(PaddleNormalDataset()) - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") + self.driver.check_dataloader_legality(dataloader) # batch_size 和 batch_sampler 均为 None 的情形 dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") - - # 创建torch的dataloader - dataloader = torch.utils.data.DataLoader( - TorchNormalDataset(), - batch_size=32, shuffle=True - ) - with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") + self.driver.check_dataloader_legality(dataloader) @pytest.mark.torchpaddle - def test_check_dataloader_legality_in_test(self): + def test_check_dataloader_legality_invalid(self): """ - 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 + 测试 check_dataloader_legality 函数传入其他类型的表现 """ - # 此时传入的应该是dict - dataloader = { - "train": DataLoader(PaddleNormalDataset()), - "test":DataLoader(PaddleNormalDataset()) - } - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") - - # batch_size 和 batch_sampler 均为 None 的情形 - dataloader = { - "train": DataLoader(PaddleNormalDataset()), - "test":DataLoader(PaddleNormalDataset(), batch_size=None) - } - with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") - - # 传入的不是 dict ,应该报错 - dataloader = DataLoader(PaddleNormalDataset()) - with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") - # 创建 torch 的 dataloader - train_loader = torch.utils.data.DataLoader( - TorchNormalDataset(), - batch_size=32, shuffle=True - ) - test_loader = torch.utils.data.DataLoader( + dataloader = torch.utils.data.DataLoader( TorchNormalDataset(), batch_size=32, shuffle=True ) - dataloader = {"train": train_loader, "test": test_loader} - with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") + with pytest.raises(TypeError): + self.driver.check_dataloader_legality(dataloader) + @pytest.mark.paddle def test_tensor_to_numeric(self): @@ -505,10 +475,14 @@ class TestSetDistReproDataloader: # 迭代两个 batch num_consumed_batches = 2 already_seen_idx = set() + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): + sampler_states = replaced_loader.batch_sampler.set_epoch(5) + else: + sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5) for idx, batch in enumerate(replaced_loader): if idx >= num_consumed_batches: break - already_seen_idx.update(batch) + already_seen_idx.update(batch.tolist()) if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: @@ -529,6 +503,7 @@ class TestSetDistReproDataloader: ) ) new_loader.batch_sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.set_epoch(5) else: batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size @@ -537,8 +512,9 @@ class TestSetDistReproDataloader: batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.sampler.set_epoch(5) for idx, batch in enumerate(new_loader): - left_idxes.update(batch) + left_idxes.update(batch.tolist()) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) assert len(left_idxes | already_seen_idx) == len(self.dataset) @@ -549,7 +525,7 @@ class TestSetDistReproDataloader: # ############################################################################ -def generate_random_driver(features, labels, fp16=False, device="cpu"): +def generate_random_driver(labels, features, fp16=False, device="cpu"): """ 生成driver """ @@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict): """ try: path = "model" - dataset = PaddleRandomMaxDataset(40, 10) + dataset = PaddleNormalXYDataset(20) dataloader = DataLoader(dataset, batch_size=4) - driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") + driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") if only_state_dict: driver1.save_model(path, only_state_dict) @@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict): driver2.load_model(path, only_state_dict) for batch in dataloader: + print("?") batch = driver1.move_data_to_device(batch) res1 = driver1.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch) @@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): try: path = "model.ckp" - dataset = PaddleRandomMaxDataset(40, 10) + dataset = PaddleNormalXYDataset(40) dataloader = DataLoader( dataset=dataset, batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) ) - driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") + driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") num_consumed_batches = 2 already_seen_x_set = set() already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 3) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) + already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} @@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 3) for idx, batch in enumerate(replaced_loader): - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape((-1, )).tolist()) + left_y_batches.update(batch["y"].reshape((-1, )).tolist()) res1 = driver1.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch) assert paddle.equal_all(res1["pred"], res2["pred"]) @@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): @pytest.mark.parametrize("fp16", ([True, False])) def test_save_and_load_with_randomsampler(only_state_dict, fp16): """ - 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 """ try: path = "model.ckp" - driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") - dataset = PaddleRandomMaxDataset(40, 10) + driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") + dataset = PaddleNormalXYDataset(40) batch_sampler = BatchSampler(dataset=dataset, batch_size=4) batch_sampler.sampler = RandomSampler(dataset, True) dataloader = DataLoader( @@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): already_seen_x_set = set() already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 3) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) + already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} @@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + driver1.set_sampler_epoch(replaced_loader, 3) for idx, batch in enumerate(replaced_loader): - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape((-1, )).tolist()) + left_y_batches.update(batch["y"].reshape((-1, )).tolist()) res1 = driver1.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch) assert paddle.equal_all(res1["pred"], res2["pred"]) diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index d9e4da66..4ec0eb61 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( UnrepeatedSequentialSampler, ) from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset +from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_TORCH @@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH: import torch.distributed as dist from torch.utils.data import DataLoader, BatchSampler -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): - torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) +def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): + torch_model = TorchNormalModel_Classification_1(labels, features) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) device = [torch.device(i) for i in device] driver = TorchDDPDriver( @@ -504,10 +504,14 @@ class TestSetDistReproDataloader: num_replicas = len(self.device) num_consumed_batches = 2 already_seen_idx = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.set_epoch(4) + else: + sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4) for idx, batch in enumerate(replaced_loader): if idx >= num_consumed_batches: break - already_seen_idx.update(batch) + already_seen_idx.update(batch.tolist()) dist.barrier() if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() @@ -533,6 +537,7 @@ class TestSetDistReproDataloader: pad=True ) new_loader.batch_sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.set_epoch(4) else: batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas @@ -543,8 +548,9 @@ class TestSetDistReproDataloader: rank=driver.global_rank ) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.sampler.set_epoch(4) for idx, batch in enumerate(new_loader): - left_idxes.update(batch) + left_idxes.update(batch.tolist()) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas @@ -562,7 +568,7 @@ class TestSaveLoad: """ def setup_method(self): - self.dataset = TorchArgMaxDataset(10, 20) + self.dataset = TorchNormalXYDataset(20) @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @@ -574,7 +580,7 @@ class TestSaveLoad: path = "model" dataloader = DataLoader(self.dataset, batch_size=2) - driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) + driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) driver1.save_model(path, only_state_dict) @@ -618,8 +624,8 @@ class TestSaveLoad: path = "model.ckp" num_replicas = len(device) - driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ - generate_driver(10, 10, device=device, fp16=False) + driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ + generate_driver(20, 1, device=device, fp16=False) dataloader = dataloader_with_bucketedbatchsampler( self.dataset, length=[10 for i in range(len(self.dataset))], @@ -636,11 +642,12 @@ class TestSaveLoad: already_seen_x_set = set() already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 4) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) # 同步 dist.barrier() @@ -665,7 +672,6 @@ class TestSaveLoad: pad=True ) dist.barrier() - print("========load=======", driver1.global_rank, driver2.global_rank) load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) dist.barrier() replaced_loader = load_states.pop("dataloader") @@ -690,10 +696,11 @@ class TestSaveLoad: assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 4) for idx, batch in enumerate(replaced_loader): - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) res1 = driver1.model( batch, fastnlp_fn=driver1.model.module.model.evaluate_step, @@ -716,7 +723,6 @@ class TestSaveLoad: dist.barrier() finally: rank_zero_rm(path) - print("=======delete======") if dist.is_initialized(): dist.destroy_process_group() @@ -735,8 +741,8 @@ class TestSaveLoad: num_replicas = len(device) - driver1 = generate_driver(10, 10, device=device, fp16=fp16) - driver2 = generate_driver(10, 10, device=device, fp16=False) + driver1 = generate_driver(20, 1, device=device, fp16=fp16) + driver2 = generate_driver(20, 1, device=device, fp16=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( @@ -748,11 +754,12 @@ class TestSaveLoad: already_seen_x_set = set() already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 4) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) # 同步 dist.barrier() @@ -797,10 +804,11 @@ class TestSaveLoad: assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + driver2.set_sampler_epoch(replaced_loader, 4) for idx, batch in enumerate(replaced_loader): - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) res1 = driver1.model( batch, fastnlp_fn=driver1.model.module.model.evaluate_step, diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index 950477c1..f3e04541 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -14,7 +14,7 @@ else: @pytest.mark.torch def test_incorrect_driver(): - model = TorchNormalModel_Classification_1(2, 100) + model = TorchNormalModel_Classification_1(20, 10) with pytest.raises(ValueError): driver = initialize_torch_driver("paddle", 0, model) @@ -33,7 +33,7 @@ def test_get_single_device(driver, device): 测试正常情况下初始化TorchSingleDriver的情况 """ - model = TorchNormalModel_Classification_1(2, 100) + model = TorchNormalModel_Classification_1(20, 10) driver = initialize_torch_driver(driver, device, model) assert isinstance(driver, TorchSingleDriver) @@ -52,7 +52,7 @@ def test_get_ddp(driver, device): 测试 ddp 多卡的初始化情况 """ - model = TorchNormalModel_Classification_1(64, 10) + model = TorchNormalModel_Classification_1(20, 10) driver = initialize_torch_driver(driver, device, model) assert isinstance(driver, TorchDDPDriver) @@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device): """ 测试传入的device超过范围的情况 """ - model = TorchNormalModel_Classification_1(2, 100) + model = TorchNormalModel_Classification_1(20, 10) with pytest.raises(ValueError): driver = initialize_torch_driver(driver, device, model) \ No newline at end of file diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 4d92b05a..73ffbb8d 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -6,7 +6,7 @@ from pkg_resources import parse_version from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset +from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from fastNLP.envs.distributed import rank_zero_rm @@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from torch.utils.data import DataLoader, BatchSampler + if _NEED_IMPORT_PADDLE: import paddle @@ -67,95 +68,67 @@ class TestTorchDriverFunctions: model = TorchNormalModel_Classification_1(10, 32) self.driver = TorchSingleDriver(model, device="cpu") - @pytest.mark.torchpaddle - def test_check_single_optimizer_legality(self): + @pytest.mark.torch + def test_check_optimizers_legality(self): """ - 测试传入单个 optimizer 时的表现 + 测试对合法 optimizers 的检查 """ + # 单个 optimizer optimizer = torch.optim.Adam( params=self.driver.model.parameters(), lr=0.01 ) - self.driver.set_optimizers(optimizer) - - optimizer = paddle.optimizer.Adam( - parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), - learning_rate=0.01, - ) - # 传入 torch 的 optimize r时,应该报错 ValueError - with pytest.raises(ValueError): - self.driver.set_optimizers(optimizer) - - @pytest.mark.torchpaddle - def test_check_optimizers_legality(self): - """ - 测试传入 optimizer list 的表现 - """ + # 列表 optimizers = [ torch.optim.Adam( params=self.driver.model.parameters(), lr=0.01 ) for i in range(10) ] - self.driver.set_optimizers(optimizers) - optimizers += [ + @pytest.mark.torchpaddle + def test_invalid_optimizers(self): + """ + 测试传入非法的 optimizers + """ + optimizer = paddle.optimizer.Adam( + parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), + learning_rate=0.01, + ) + with pytest.raises(TypeError): + self.driver.set_optimizers(optimizer) + + optimizers = [ paddle.optimizer.Adam( parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), learning_rate=0.01, ) ] - - with pytest.raises(ValueError): + with pytest.raises(TypeError): self.driver.set_optimizers(optimizers) - @pytest.mark.torchpaddle - def test_check_dataloader_legality_in_train(self): + @pytest.mark.torch + def test_check_dataloader_legality(self): """ - 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 + 测试 check_dataloader_legality 函数的表现 """ dataloader = DataLoader(TorchNormalDataset()) - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") - - # 创建 paddle 的 dataloader - dataloader = paddle.io.DataLoader( - PaddleNormalDataset(), - batch_size=32, shuffle=True - ) - with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") + self.driver.check_dataloader_legality(dataloader) @pytest.mark.torchpaddle - def test_check_dataloader_legality_in_test(self): + def test_check_dataloader_legality_invalid(self): """ - 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 + 测试 check_dataloader_legality 函数传入其他类型的表现 """ - # 此时传入的应该是dict - dataloader = { - "train": DataLoader(TorchNormalDataset()), - "test": DataLoader(TorchNormalDataset()) - } - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") - - # 传入的不是 dict,应该报错 - dataloader = DataLoader(TorchNormalDataset()) - with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") - # 创建 paddle 的 dataloader - train_loader = paddle.io.DataLoader( - PaddleNormalDataset(), - batch_size=32, shuffle=True - ) - test_loader = paddle.io.DataLoader( + dataloader = paddle.io.DataLoader( PaddleNormalDataset(), batch_size=32, shuffle=True ) - dataloader = {"train": train_loader, "test": test_loader} - with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") + with pytest.raises(TypeError): + self.driver.check_dataloader_legality(dataloader) @pytest.mark.torch def test_tensor_to_numeric(self): @@ -515,10 +488,14 @@ class TestSetDistReproDataloader: # 迭代两个 batch num_consumed_batches = 2 already_seen_idx = set() + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): + replaced_loader.batch_sampler.set_epoch(3) + else: + replaced_loader.batch_sampler.sampler.set_epoch(3) for idx, batch in enumerate(replaced_loader): if idx >= num_consumed_batches: break - already_seen_idx.update(batch) + already_seen_idx.update(batch.tolist()) if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: @@ -532,14 +509,16 @@ class TestSetDistReproDataloader: # 重新改造 dataloader new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) new_loader.batch_sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.set_epoch(3) else: batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新构造 dataloader new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + new_loader.batch_sampler.sampler.set_epoch(3) for idx, batch in enumerate(new_loader): - left_idxes.update(batch) + left_idxes.update(batch.tolist()) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) assert len(left_idxes | already_seen_idx) == len(self.dataset) @@ -550,7 +529,7 @@ class TestSetDistReproDataloader: # ############################################################################ -def generate_random_driver(features, labels, fp16=False, device="cpu"): +def generate_random_driver(labels, features, fp16=False, device="cpu"): """ 生成driver """ @@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict): """ try: path = "model" - dataset = TorchArgMaxDataset(10, 40) + dataset = TorchNormalXYDataset(20) dataloader = DataLoader(dataset, batch_size=4) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) driver1.save_model(path, only_state_dict) driver2.load_model(path, only_state_dict) @@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): try: path = "model.ckp" - dataset = TorchArgMaxDataset(10, 40) + dataset = TorchNormalXYDataset(20) dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) - driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") + driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") num_consumed_batches = 2 already_seen_x_set = set() already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 3) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} @@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + driver1.set_sampler_epoch(replaced_loader, 3) for idx, batch in enumerate(replaced_loader): batch = driver2.move_data_to_device(batch) - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) res1 = driver1.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch) assert torch.equal(res1["preds"], res2["preds"]) @@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): @pytest.mark.parametrize("fp16", ([True, False])) def test_save_and_load_with_randomsampler(only_state_dict, fp16): """ - 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 """ try: path = "model.ckp" - driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") - dataset = TorchArgMaxDataset(10, 40) + driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") + dataset = TorchNormalXYDataset(40) dataloader = dataloader_with_randomsampler(dataset, 4, True, False) num_consumed_batches = 2 already_seen_x_set = set() already_seen_y_set = set() + driver1.set_sampler_epoch(dataloader, 3) for idx, batch in enumerate(dataloader): if idx >= num_consumed_batches: break - already_seen_x_set.update(batch["x"]) - already_seen_y_set.update(batch["y"]) + already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) + already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} @@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): assert start_batch == 2 * num_consumed_batches left_x_batches = set() left_y_batches = set() + # set epoch + driver2.set_sampler_epoch(replaced_loader, 3) for idx, batch in enumerate(replaced_loader): batch = driver2.move_data_to_device(batch) - left_x_batches.update(batch["x"]) - left_y_batches.update(batch["y"]) + left_x_batches.update(batch["x"].reshape(-1, ).tolist()) + left_y_batches.update(batch["y"].reshape(-1, ).tolist()) res1 = driver1.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch) assert torch.equal(res1["preds"], res2["preds"]) diff --git a/tests/helpers/datasets/jittor_data.py b/tests/helpers/datasets/jittor_data.py new file mode 100644 index 00000000..f97445d9 --- /dev/null +++ b/tests/helpers/datasets/jittor_data.py @@ -0,0 +1,46 @@ +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR + +if _NEED_IMPORT_JITTOR: + import jittor as jt + from jittor.dataset import Dataset +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset + +class JittorNormalDataset(Dataset): + def __init__(self, num_of_data=100, **kwargs): + super(JittorNormalDataset, self).__init__(**kwargs) + self._data = list(range(num_of_data)) + self.set_attrs(total_len=num_of_data) + + def __getitem__(self, item): + return self._data[item] + +class JittorNormalXYDataset(Dataset): + """ + 可以被输入到分类模型中的普通数据集 + """ + def __init__(self, num_of_data=1000, **kwargs): + super(JittorNormalXYDataset, self).__init__(**kwargs) + self.num_of_data = num_of_data + self._data = list(range(num_of_data)) + self.set_attrs(total_len=num_of_data) + + def __getitem__(self, item): + return { + "x": jt.Var([self._data[item]]), + "y": jt.Var([self._data[item]]) + } + +class JittorArgMaxDataset(Dataset): + def __init__(self, num_samples, num_features, **kwargs): + super(JittorArgMaxDataset, self).__init__(**kwargs) + self.x = jt.randn(num_samples, num_features) + self.y = self.x.argmax(dim=-1) + self.set_attrs(total_len=num_samples) + + def __getitem__(self, item): + return {"x": self.x[item], "y": self.y[item]} + +if __name__ == "__main__": + dataset = JittorNormalDataset() + print(len(dataset)) \ No newline at end of file diff --git a/tests/helpers/datasets/paddle_data.py b/tests/helpers/datasets/paddle_data.py index 8a8d39b1..16fe10c2 100644 --- a/tests/helpers/datasets/paddle_data.py +++ b/tests/helpers/datasets/paddle_data.py @@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset): def __getitem__(self, item): return self._data[item] +class PaddleNormalXYDataset(Dataset): + """ + 可以被输入到分类模型中的普通数据集 + """ + def __init__(self, num_of_data=1000): + self.num_of_data = num_of_data + self._data = list(range(num_of_data)) + + def __len__(self): + return self.num_of_data + + def __getitem__(self, item): + return { + "x": paddle.to_tensor([self._data[item]], dtype="float32"), + "y": paddle.to_tensor([self._data[item]], dtype="float32") + } -class PaddleRandomMaxDataset(Dataset): +class PaddleArgMaxDataset(Dataset): def __init__(self, num_samples, num_features): self.x = paddle.randn((num_samples, num_features)) self.y = self.x.argmax(axis=-1) diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py index 1244a2f6..55c21acc 100644 --- a/tests/helpers/datasets/torch_data.py +++ b/tests/helpers/datasets/torch_data.py @@ -1,4 +1,6 @@ from functools import reduce + +from numpy import dtype from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: @@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset): def __getitem__(self, item): return self._data[item] +class TorchNormalXYDataset(Dataset): + """ + 可以被输入到分类模型中的普通数据集 + """ + def __init__(self, num_of_data=1000): + self.num_of_data = num_of_data + self._data = list(range(num_of_data)) + + def __len__(self): + return self.num_of_data + + def __getitem__(self, item): + return { + "x": torch.tensor([self._data[item]], dtype=torch.float), + "y": torch.tensor([self._data[item]], dtype=torch.float) + } + # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; class TorchNormalDataset_Classification(Dataset): diff --git a/tests/helpers/models/jittor_model.py b/tests/helpers/models/jittor_model.py new file mode 100644 index 00000000..feabb94b --- /dev/null +++ b/tests/helpers/models/jittor_model.py @@ -0,0 +1,57 @@ +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR +if _NEED_IMPORT_JITTOR: + from jittor import Module, nn +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module + +class JittorNormalModel_Classification_1(Module): + """ + 基础的 jittor 分类模型 + """ + def __init__(self, num_labels, feature_dimension): + super(JittorNormalModel_Classification_1, self).__init__() + self.num_labels = num_labels + + self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) + self.ac1 = nn.ReLU() + self.linear2 = nn.Linear(in_features=64, out_features=32) + self.ac2 = nn.ReLU() + self.output = nn.Linear(in_features=32, out_features=num_labels) + self.loss_fn = nn.CrossEntropyLoss() + + def execute(self, x): + x = self.ac1(self.linear1(x)) + x = self.ac2(self.linear2(x)) + x = self.output(x) + return x + + def train_step(self, x, y): + x = self(x) + return {"loss": self.loss_fn(x, y)} + + def evaluate_step(self, x, y): + + x = self(x) + return {"pred": x, "target": y.reshape((-1,))} + + +class JittorNormalModel_Classification_2(Module): + """ + 基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景 + """ + def __init__(self, num_labels, feature_dimension): + super(JittorNormalModel_Classification_2, self).__init__() + self.num_labels = num_labels + + self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) + self.ac1 = nn.ReLU() + self.linear2 = nn.Linear(in_features=64, out_features=32) + self.ac2 = nn.ReLU() + self.output = nn.Linear(in_features=32, out_features=num_labels) + self.loss_fn = nn.CrossEntropyLoss() + + def execute(self, x, y): + x = self.ac1(self.linear1(x)) + x = self.ac2(self.linear2(x)) + x = self.output(x) + return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} diff --git a/tests/helpers/models/paddle_model.py b/tests/helpers/models/paddle_model.py index d2969b8e..b55b2bfe 100644 --- a/tests/helpers/models/paddle_model.py +++ b/tests/helpers/models/paddle_model.py @@ -8,7 +8,7 @@ else: class PaddleNormalModel_Classification_1(Layer): """ - 基础的paddle分类模型 + 基础的 paddle 分类模型 """ def __init__(self, num_labels, feature_dimension): super(PaddleNormalModel_Classification_1, self).__init__() @@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer): class PaddleNormalModel_Classification_2(Layer): """ - 基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 + 基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 """ def __init__(self, num_labels, feature_dimension): super(PaddleNormalModel_Classification_2, self).__init__() @@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer): x = self.ac1(self.linear1(x)) x = self.ac2(self.linear2(x)) x = self.output(x) - loss = self.loss_fn(x, y) return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} diff --git a/tests/modules/mix_modules/test_utils.py b/tests/modules/mix_modules/test_utils.py index c046d648..8d5982aa 100644 --- a/tests/modules/mix_modules/test_utils.py +++ b/tests/modules/mix_modules/test_utils.py @@ -33,7 +33,11 @@ class TestPaddle2Torch: """ assert isinstance(tensor, torch.Tensor) - assert tensor.device == torch.device(device) + if device == "cpu": + assert not tensor.is_cuda + else: + assert tensor.is_cuda + assert tensor.device.index == torch.device(device).index assert tensor.requires_grad == requires_grad def test_gradient(self): @@ -261,7 +265,8 @@ class TestJittor2Torch: if device == "cpu": assert not tensor.is_cuda else: - assert tensor.device == torch.device(device) + assert tensor.is_cuda + assert tensor.device.index == torch.device(device).index assert tensor.requires_grad == requires_grad def test_var_transfer(self): @@ -271,7 +276,10 @@ class TestJittor2Torch: jittor_var = jittor.rand((3, 4, 5)) res = jittor2torch(jittor_var) - self.check_torch_tensor(res, "cpu", True) + if jittor.flags.use_cuda: + self.check_torch_tensor(res, "cuda:0", True) + else: + self.check_torch_tensor(res, "cpu", True) res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) self.check_torch_tensor(res, "cuda:2", True) @@ -291,7 +299,10 @@ class TestJittor2Torch: res = jittor2torch(jittor_list) assert isinstance(res, list) for t in res: - self.check_torch_tensor(t, "cpu", True) + if jittor.flags.use_cuda: + self.check_torch_tensor(t, "cuda:0", True) + else: + self.check_torch_tensor(t, "cpu", True) res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) assert isinstance(res, list) @@ -327,17 +338,29 @@ class TestJittor2Torch: } res = jittor2torch(jittor_dict) assert isinstance(res, dict) - self.check_torch_tensor(res["tensor"], "cpu", True) + if jittor.flags.use_cuda: + self.check_torch_tensor(res["tensor"], "cuda:0", True) + else: + self.check_torch_tensor(res["tensor"], "cpu", True) assert isinstance(res["list"], list) for t in res["list"]: - self.check_torch_tensor(t, "cpu", True) + if jittor.flags.use_cuda: + self.check_torch_tensor(t, "cuda:0", True) + else: + self.check_torch_tensor(t, "cpu", True) assert isinstance(res["int"], int) assert isinstance(res["string"], str) assert isinstance(res["dict"], dict) assert isinstance(res["dict"]["list"], list) for t in res["dict"]["list"]: - self.check_torch_tensor(t, "cpu", True) - self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) + if jittor.flags.use_cuda: + self.check_torch_tensor(t, "cuda:0", True) + else: + self.check_torch_tensor(t, "cpu", True) + if jittor.flags.use_cuda: + self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True) + else: + self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) ############################################################################