@@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||||
class _JittorDataset(Dataset): | class _JittorDataset(Dataset): | ||||
""" | """ | ||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | ||||
""" | """ | ||||
def __init__(self, dataset) -> None: | def __init__(self, dataset) -> None: | ||||
@@ -83,7 +82,7 @@ class JittorDataLoader: | |||||
# TODO 验证支持replacesampler (以后完成) 增加Sampler | # TODO 验证支持replacesampler (以后完成) 增加Sampler | ||||
# 将内部dataset批次设置为1 | # 将内部dataset批次设置为1 | ||||
if isinstance(dataset, Dataset): | if isinstance(dataset, Dataset): | ||||
dataset.set_attrs(batch_size=1) | |||||
dataset.set_attrs(batch_size=1, shuffle=False, endless=False) | |||||
# FastNLP Datset, collate_fn not None | # FastNLP Datset, collate_fn not None | ||||
if isinstance(dataset, FDataSet) and collate_fn is None: | if isinstance(dataset, FDataSet) and collate_fn is None: | ||||
@@ -115,6 +114,12 @@ class JittorDataLoader: | |||||
self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
def __getattr__(self, attr): | |||||
if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad", | |||||
"keep_numpy_array", "endless", "sampler"]: | |||||
return getattr(self.dataset, attr) | |||||
raise AttributeError(f"{self} has not attribute '{attr}'") | |||||
def __iter__(self): | def __iter__(self): | ||||
# TODO 第一次迭代后不能设置collate_fn,设置是无效的 | # TODO 第一次迭代后不能设置collate_fn,设置是无效的 | ||||
if self.cur_batch_indices is None: | if self.cur_batch_indices is None: | ||||
@@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR: | |||||
__all__ = [] | __all__ = [] | ||||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver: | |||||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver: | |||||
r""" | r""" | ||||
用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 | 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 | ||||
@@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | ||||
# TODO 实现更详细的判断 | # TODO 实现更详细的判断 | ||||
if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: | |||||
if device in ["cpu", "gpu", "cuda", None]: | |||||
return JittorSingleDriver(model, device, **kwargs) | return JittorSingleDriver(model, device, **kwargs) | ||||
elif type(device) is int: | elif type(device) is int: | ||||
return JittorMPIDriver(model, device, **kwargs) | return JittorMPIDriver(model, device, **kwargs) | ||||
@@ -1,23 +1,31 @@ | |||||
import os | import os | ||||
import random | |||||
from pathlib import Path | from pathlib import Path | ||||
from typing import Union, Optional | |||||
from functools import partial | |||||
import numpy as np | |||||
from typing import Union, Optional, Dict | |||||
from contextlib import nullcontext | |||||
from dataclasses import dataclass | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
from fastNLP.core.dataloaders import JittorDataLoader | from fastNLP.core.dataloaders import JittorDataLoader | ||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS | |||||
from fastNLP.envs import ( | |||||
FASTNLP_MODEL_FILENAME, | |||||
FASTNLP_CHECKPOINT_FILENAME, | |||||
) | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor as jt | import jittor as jt | ||||
from jittor import Module | from jittor import Module | ||||
from jittor.optim import Optimizer | from jittor.optim import Optimizer | ||||
from jittor.dataset import Dataset | from jittor.dataset import Dataset | ||||
from jittor.dataset import ( | |||||
BatchSampler as JittorBatchSampler, | |||||
Sampler as JittorSampler, | |||||
RandomSampler as JittorRandomSampler, | |||||
SequentialSampler as JittorSequentialSampler | |||||
) | |||||
_reduces = { | _reduces = { | ||||
'max': jt.max, | 'max': jt.max, | ||||
@@ -56,6 +64,7 @@ class JittorDriver(Driver): | |||||
else: | else: | ||||
jt.flags.auto_mixed_precision_level = 0 | jt.flags.auto_mixed_precision_level = 0 | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
self._auto_cast = nullcontext | |||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
@@ -68,7 +77,7 @@ class JittorDriver(Driver): | |||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
if not isinstance(each_optimizer, Optimizer): | if not isinstance(each_optimizer, Optimizer): | ||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||||
f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
def step(self): | def step(self): | ||||
@@ -117,30 +126,118 @@ class JittorDriver(Driver): | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
model.load(filepath) | model.load(filepath) | ||||
def save_checkpoint(self): | |||||
... | |||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
sampler_states = sampler.state_dict() | |||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。因为 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
else: | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
states['sampler_states'] = sampler_states | |||||
else: | |||||
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||||
'state.') | |||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
self.save_model(model_path, only_state_dict=only_state_dict) | |||||
# 3. 保存 optimizers 的状态; | |||||
states["optimizers_state_dict"] = self.get_optimizer_state() | |||||
# 4. 保存fp16的状态 | |||||
logger.debug("Save optimizer state dict") | |||||
jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
def get_optimizer_state(self): | def get_optimizer_state(self): | ||||
# optimizers_state_dict = {} | |||||
# for i in range(len(self.optimizers)): | |||||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
# optimizer_state = optimizer.state_dict() | |||||
# optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
# optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
# return optimizers_state_dict | |||||
... | |||||
optimizers_state_dict = {} | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: Optimizer = self.optimizers[i] | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
return optimizers_state_dict | |||||
def load_optimizer_state(self, states): | def load_optimizer_state(self, states): | ||||
# assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
# f"checkpoint it is:{len(states)}" | |||||
# for i in range(len(self.optimizers)): | |||||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
# optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
# logger.debug("Load optimizer state dict.") | |||||
... | |||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
f"checkpoint it is:{len(states)}" | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: Optimizer = self.optimizers[i] | |||||
optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
logger.debug("Load optimizer state dict.") | |||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||||
# 1. 加载 optimizers 的状态; | |||||
optimizers_state_dict = states.pop("optimizers_state_dict") | |||||
self.load_optimizer_state(optimizers_state_dict) | |||||
# 2. 加载模型状态; | |||||
if should_load_model: | |||||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||||
# 3. 加载fp16的状态 | |||||
# 4. 恢复 sampler 的状态; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if dataloader_args.sampler is None: | |||||
sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=dataloader_args.shuffle) | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif isinstance(dataloader_args.sampler, JittorRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.dataset) | |||||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||||
elif isinstance(dataloader_args.sampler, JittorSequentialSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False) | |||||
logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.") | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
"`ReproducibleSampler`.") | |||||
else: | |||||
raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.") | |||||
sampler.load_state_dict(states.pop('sampler_states')) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
def load_checkpoint(self): | |||||
... | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
return states | |||||
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
return jt.no_grad | return jt.no_grad | ||||
@@ -198,26 +295,8 @@ class JittorDriver(Driver): | |||||
""" | """ | ||||
return batch | return batch | ||||
@staticmethod | |||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
process_seed = jt.get_seed() | |||||
# back out the base seed so we can use all the bits | |||||
base_seed = process_seed - worker_id | |||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||||
# use 128 bits (4 x 32-bit words) | |||||
np.random.seed(ss.generate_state(4)) | |||||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||||
jittor_ss, stdlib_ss = ss.spawn(2) | |||||
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) | |||||
# use 128 bits expressed as an integer | |||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||||
random.seed(stdlib_seed) | |||||
def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | ||||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||||
... | |||||
def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | ||||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | ||||
@@ -226,4 +305,45 @@ class JittorDriver(Driver): | |||||
@staticmethod | @staticmethod | ||||
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | ||||
pass | |||||
@dataclass | |||||
class Res: | |||||
dataset: Optional[Dataset] = None | |||||
batch_sampler: Optional[JittorBatchSampler] = None | |||||
sampler: Optional[JittorSampler] = None | |||||
batch_size: Optional[int] = None | |||||
shuffle: Optional[bool] = None | |||||
drop_last: Optional[bool] = None | |||||
res = Res() | |||||
from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset | |||||
if isinstance(dataloader, JittorDataLoader): | |||||
# JittorDataLoader 实际上是迭代 dataset 成员的 | |||||
dataloader = dataloader.dataset | |||||
if isinstance(dataloader, _JittorDataset): | |||||
# 获取最原始的 dataset | |||||
res.dataset = dataloader.dataset | |||||
else: | |||||
res.dataset = dataloader | |||||
# jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取 | |||||
res.batch_size = dataloader.batch_size | |||||
res.drop_last = dataloader.drop_last | |||||
if dataloader.sampler is None: | |||||
# sampler 是 None,那么就从 Dataset 的属性中获取 | |||||
res.shuffle = dataloader.shuffle | |||||
elif isinstance(list(dataloader.sampler.__iter__())[0], (list,tuple)): | |||||
# jittor 目前不支持 batch_sampler | |||||
raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||||
"please check if you have set `Dataset.sampler` as `BatchSampler`") | |||||
else: | |||||
# sampler 不为 None | |||||
res.sampler = dataloader.sampler | |||||
if hasattr(dataloader.sampler, "shuffle"): | |||||
# 这种情况一般出现在 fastNLP 的 ReproduceSampler 中 | |||||
res.shuffle = dataloader.sampler.shuffle | |||||
elif isinstance(dataloader.sampler, JittorRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
return res |
@@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver): | |||||
): | ): | ||||
super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
raise NotImplementedError("MPI for Jittor is not supported right now.") | |||||
self.is_pull_by_jittor_run = is_pull_by_jittor_run | self.is_pull_by_jittor_run = is_pull_by_jittor_run | ||||
self.parallel_device = parallel_device | self.parallel_device = parallel_device | ||||
@@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver): | |||||
return self._data_device | return self._data_device | ||||
return self.parallel_device | return self.parallel_device | ||||
def step(self): | |||||
# for optimizer in self.optimizers: | |||||
# self.grad_scaler.step(optimizer) | |||||
# self.grad_scaler.update() | |||||
for optimizer in self.optimizers: | |||||
optimizer.step() | |||||
def backward(self, loss): | |||||
# self.grad_scaler.scale(loss).backward() | |||||
for optimizer in self.optimizers: | |||||
optimizer.backward(loss) | |||||
def zero_grad(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.zero_grad() | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | return auto_param_call(fn, batch, signature_fn=signature_fn) | ||||
@@ -1,14 +1,21 @@ | |||||
from typing import Dict, Union, Tuple, Callable, Optional | from typing import Dict, Union, Tuple, Callable, Optional | ||||
from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
from .utils import replace_batch_sampler, replace_sampler | |||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||||
ReproduceBatchSampler | |||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor as jt | import jittor as jt | ||||
from jittor.dataset import ( | |||||
RandomSampler as JittorRandomSampler, | |||||
SequentialSampler as JittorSequentialSampler, | |||||
) | |||||
__all__ = [ | __all__ = [ | ||||
"JittorSingleDriver", | "JittorSingleDriver", | ||||
@@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver): | |||||
""" | """ | ||||
return False | return False | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
# reproducible 的相关功能暂时没有实现 | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||||
reproducible: bool = False): | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
raise NotImplementedError | |||||
dataloader.batch_sampler = dist_sample | |||||
if isinstance(dist, ReproducibleSampler): | |||||
raise NotImplementedError | |||||
dataloader.batch_sampler.sampler = dist | |||||
return replace_batch_sampler(dataloader, dist) | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | if reproducible: | ||||
raise NotImplementedError | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | |||||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||||
return dataloader | |||||
else: | |||||
# TODO | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | |||||
batch_size=dataloader.batch_sampler.batch_size, | |||||
drop_last=dataloader.drop_last | |||||
) | |||||
dataloader.batch_sampler = batch_sampler | |||||
return dataloader | |||||
if args.sampler is None: | |||||
sampler = RandomSampler(args.dataset, args.shuffle) | |||||
return replace_sampler(dataloader, sampler) | |||||
elif isinstance(args.sampler, JittorRandomSampler): | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'rep', False) is False: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.dataset, shuffle=True) | |||||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
elif isinstance(args.sampler, JittorSequentialSampler): | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.dataset, shuffle=False) | |||||
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
@@ -1,6 +1,29 @@ | |||||
import inspect | |||||
from copy import deepcopy | |||||
from typing import Union | |||||
from fastNLP.core.dataloaders import JittorDataLoader | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | |||||
from jittor.dataset import Dataset | |||||
__all__ = [] | __all__ = [] | ||||
def replace_batch_sampler(dataloader, batch_sampler): | |||||
raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||||
"please check if you have set `Dataset.sampler` as `BatchSampler`" | |||||
"or report this bug to us.") | |||||
def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): | |||||
if isinstance(dataloader, JittorDataLoader): | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} | |||||
reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler) | |||||
new_dataloader = type(dataloader)(**reconstruct_args) | |||||
new_dataloader.dataset.set_attrs(sampler=sampler) | |||||
else: | |||||
new_dataloader = deepcopy(dataloader) | |||||
new_dataloader.set_attrs(sampler=sampler) | |||||
return new_dataloader |
@@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle.io import ( | from paddle.io import ( | ||||
DataLoader, | DataLoader, | ||||
IterableDataset, | |||||
Dataset, | Dataset, | ||||
Sampler, | Sampler, | ||||
BatchSampler, | BatchSampler, | ||||
@@ -97,6 +96,9 @@ class PaddleDriver(Driver): | |||||
def check_dataloader_legality(self, dataloader): | def check_dataloader_legality(self, dataloader): | ||||
if not isinstance(dataloader, DataLoader): | if not isinstance(dataloader, DataLoader): | ||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | ||||
if dataloader.batch_size is None and dataloader.batch_sampler is None: | |||||
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" | |||||
"is not None") | |||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
@@ -107,7 +109,7 @@ class PaddleDriver(Driver): | |||||
""" | """ | ||||
for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
if not isinstance(each_optimizer, Optimizer): | if not isinstance(each_optimizer, Optimizer): | ||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||||
f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
@staticmethod | @staticmethod | ||||
@@ -263,9 +265,7 @@ class PaddleDriver(Driver): | |||||
optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
optimizer: Optimizer = self.optimizers[i] | optimizer: Optimizer = self.optimizers[i] | ||||
optimizer_state = optimizer.state_dict() | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu") | |||||
return optimizers_state_dict | return optimizers_state_dict | ||||
@@ -399,6 +399,8 @@ class PaddleDriver(Driver): | |||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | ||||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | ||||
dataloader.batch_sampler.set_epoch(cur_epoch_idx) | dataloader.batch_sampler.set_epoch(cur_epoch_idx) | ||||
elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)): | |||||
dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx) | |||||
@staticmethod | @staticmethod | ||||
def get_dataloader_args(dataloader: "DataLoader"): | def get_dataloader_args(dataloader: "DataLoader"): | ||||
@@ -99,7 +99,7 @@ class TorchDriver(Driver): | |||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
if not isinstance(each_optimizer, Optimizer): | if not isinstance(each_optimizer, Optimizer): | ||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
@staticmethod | @staticmethod | ||||
@@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
self.during_iter = True | self.during_iter = True | ||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
indices = list(range(self.num_samples)) | |||||
if self.shuffle: | if self.shuffle: | ||||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | ||||
@@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
if len(indices)%self.batch_size!=0: | if len(indices)%self.batch_size!=0: | ||||
batches.append(indices[_num_batches*self.batch_size:]) | batches.append(indices[_num_batches*self.batch_size:]) | ||||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||||
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | ||||
if len(batches) > 0: | if len(batches) > 0: | ||||
if len(batches[-1])<self.batch_size: | if len(batches[-1])<self.batch_size: | ||||
@@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | else: | ||||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | (self.num_left_samples + self.batch_size - 1) // self.batch_size | ||||
@property | @property | ||||
@@ -313,8 +313,12 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||||
@property | |||||
def num_samples(self): | |||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
def __len__(self)->int: | def __len__(self)->int: | ||||
""" | """ | ||||
@@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
" consumed. ") | " consumed. ") | ||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | 'batch_size': self.batch_size, | ||||
'num_replicas': self.num_replicas} | 'num_replicas': self.num_replicas} | ||||
@@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||||
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | "and current dataset." | ||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
@@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||||
@property | |||||
def num_samples(self): | |||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
def __len__(self)->int: | def __len__(self)->int: | ||||
""" | """ | ||||
@@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
if len(sorted_indices)%self.batch_size!=0: | if len(sorted_indices)%self.batch_size!=0: | ||||
batches.append(sorted_indices[_num_batches*self.batch_size:]) | batches.append(sorted_indices[_num_batches*self.batch_size:]) | ||||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||||
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | ||||
if len(batches) > 0: | if len(batches) > 0: | ||||
if len(batches[-1])<self.batch_size: | if len(batches[-1])<self.batch_size: | ||||
@@ -593,7 +601,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
" consumed. ") | " consumed. ") | ||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | ||||
'num_replicas': self.num_replicas | 'num_replicas': self.num_replicas | ||||
} | } | ||||
@@ -609,7 +617,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||||
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | "and current dataset." | ||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
@@ -630,7 +638,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | else: | ||||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | (self.num_left_samples + self.batch_size - 1) // self.batch_size |
@@ -48,6 +48,10 @@ class ReproducibleSampler: | |||||
def num_left_samples(self): | def num_left_samples(self): | ||||
raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | ||||
@property | |||||
def num_samples(self): | |||||
raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.") | |||||
def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
pass | pass | ||||
@@ -131,19 +135,19 @@ class RandomSampler(ReproducibleSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.shuffle: | if self.shuffle: | ||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
indices = list(range(self.num_samples)) | |||||
seed = self.seed + self.epoch | seed = self.seed + self.epoch | ||||
rng = np.random.default_rng(abs(seed)) | rng = np.random.default_rng(abs(seed)) | ||||
rng.shuffle(indices) | rng.shuffle(indices) | ||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | ||||
self.epoch -= 1 | self.epoch -= 1 | ||||
else: | else: | ||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
indices = list(range(self.num_samples)) | |||||
return indices | return indices | ||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} | |||||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle} | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
@@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||||
f"record({length}) and current dataset({self.num_samples})." | |||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
@@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||||
@property | |||||
def num_samples(self): | |||||
""" | |||||
返回样本的总数 | |||||
:return: | |||||
""" | |||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
""" | """ | ||||
@@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
return list(range(self.num_samples)) | |||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||||
'length': getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
} | |||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples} | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
@@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||||
f"record({length}) and current dataset({self.num_samples})." | |||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
@@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler): | |||||
except BaseException as e: | except BaseException as e: | ||||
logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | ||||
assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ | |||||
f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." | |||||
assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." | |||||
assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \ | |||||
f"`length`({self.num_samples}) should be equal." | |||||
assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length." | |||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | ||||
@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | ||||
:return: | :return: | ||||
""" | """ | ||||
num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas | |||||
num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) | |||||
num_common = self.num_samples//self.num_replicas | |||||
num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas)) | |||||
return num_samples | return num_samples | ||||
def __iter__(self): | def __iter__(self): | ||||
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.shuffle: | if self.shuffle: | ||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
indices = list(range(self.num_samples)) | |||||
seed = self.seed + self.epoch | seed = self.seed + self.epoch | ||||
rng = np.random.default_rng(abs(seed)) | rng = np.random.default_rng(abs(seed)) | ||||
rng.shuffle(indices) | rng.shuffle(indices) | ||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | ||||
self.epoch -= 1 | self.epoch -= 1 | ||||
else: | else: | ||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
indices = list(range(self.num_samples)) | |||||
return indices | return indices | ||||
def set_epoch(self, epoch: int) -> None: | def set_epoch(self, epoch: int) -> None: | ||||
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:param rank: | :param rank: | ||||
:return: | :return: | ||||
""" | """ | ||||
assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
f"number of samples({self.num_samples})." | |||||
assert num_replicas>0 and isinstance(num_replicas, int) | assert num_replicas>0 and isinstance(num_replicas, int) | ||||
assert isinstance(rank, int) and 0<=rank<num_replicas | assert isinstance(rank, int) and 0<=rank<num_replicas | ||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | ||||
@@ -94,6 +94,15 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
return self | return self | ||||
@property | |||||
def num_samples(self): | |||||
""" | |||||
返回样本的总数 | |||||
:return: | |||||
""" | |||||
return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | ||||
""" | """ | ||||
@@ -147,5 +156,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||||
yield index | yield index | ||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
return list(range(self.num_samples)) | |||||
@@ -27,7 +27,7 @@ from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | from paddle.io import DataLoader | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | ||||
@dataclass | @dataclass | ||||
@@ -52,12 +52,12 @@ def test_trainer_fleet( | |||||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | ||||
train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
@@ -24,7 +24,7 @@ from paddle.io import DataLoader | |||||
import paddle.distributed.fleet as fleet | import paddle.distributed.fleet as fleet | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | ||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | ||||
@dataclass | @dataclass | ||||
@@ -54,12 +54,12 @@ def test_trainer_fleet( | |||||
optimizers = fleet.distributed_optimizer(optimizers) | optimizers = fleet.distributed_optimizer(optimizers) | ||||
train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
@@ -46,8 +46,8 @@ class LSTM(Module): | |||||
def init_hidden(self, x): | def init_hidden(self, x): | ||||
# batch_first | # batch_first | ||||
batch_size = x.shape[0] | batch_size = x.shape[0] | ||||
h0 = jt.randn(1, batch_size, hidden_size) | |||||
c0 = jt.randn(1, batch_size, hidden_size) | |||||
h0 = jt.randn(1, batch_size, self.hidden_size) | |||||
c0 = jt.randn(1, batch_size, self.hidden_size) | |||||
return h0, c0 | return h0, c0 | ||||
@@ -1,4 +1,5 @@ | |||||
import pytest | import pytest | ||||
from fastNLP.core.callbacks import callback | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.controllers.trainer import Evaluator | from fastNLP.core.controllers.trainer import Evaluator | ||||
@@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR: | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | from fastNLP.core.utils.dummy_class import DummyClass as Module | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
jt.flags.use_cuda=1 | |||||
class JittorNormalModel_Classification(Module): | class JittorNormalModel_Classification(Module): | ||||
@@ -68,11 +70,9 @@ class TrainJittorConfig: | |||||
batch_size: int = 4 | batch_size: int = 4 | ||||
shuffle: bool = True | shuffle: bool = True | ||||
@pytest.mark.parametrize("driver", ["jittor"]) | @pytest.mark.parametrize("driver", ["jittor"]) | ||||
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"]) | |||||
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) | |||||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | ||||
@pytest.mark.jittor | |||||
def test_trainer_jittor( | def test_trainer_jittor( | ||||
driver, | driver, | ||||
device, | device, | ||||
@@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE: | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@dataclass | @dataclass | ||||
@@ -44,12 +44,12 @@ def test_trainer_paddle( | |||||
) | ) | ||||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | ||||
train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||||
dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||||
batch_size=TrainPaddleConfig.batch_size, | batch_size=TrainPaddleConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||||
dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||||
batch_size=TrainPaddleConfig.batch_size, | batch_size=TrainPaddleConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
@@ -76,7 +76,7 @@ class TestPaddle: | |||||
from paddle.io import Dataset | from paddle.io import Dataset | ||||
import paddle | import paddle | ||||
class PaddleRandomMaxDataset(Dataset): | |||||
class PaddleArgMaxDataset(Dataset): | |||||
def __init__(self, num_samples, num_features): | def __init__(self, num_samples, num_features): | ||||
self.x = paddle.randn((num_samples, num_features)) | self.x = paddle.randn((num_samples, num_features)) | ||||
self.y = self.x.argmax(axis=-1) | self.y = self.x.argmax(axis=-1) | ||||
@@ -87,7 +87,7 @@ class TestPaddle: | |||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return {"x": self.x[item], "y": self.y[item]} | return {"x": self.x[item], "y": self.y[item]} | ||||
ds = PaddleRandomMaxDataset(100, 2) | |||||
ds = PaddleArgMaxDataset(100, 2) | |||||
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | ||||
for batch in dl: | for batch in dl: | ||||
print(batch) | print(batch) |
@@ -0,0 +1,45 @@ | |||||
import pytest | |||||
from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver | |||||
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver | |||||
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor as jt | |||||
@pytest.mark.jittor | |||||
def test_incorrect_driver(): | |||||
model = JittorNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_jittor_driver("torch", 0, model) | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
["cpu", "gpu", None, "cuda"] | |||||
) | |||||
def test_get_single_device(device): | |||||
""" | |||||
测试正常情况下初始化 JittorSingleDriver 的情况 | |||||
""" | |||||
model = JittorNormalModel_Classification_1(20, 10) | |||||
driver = initialize_jittor_driver("jittor", device, model) | |||||
assert isinstance(driver, JittorSingleDriver) | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[[0, 2, 3], 1, 2] | |||||
) | |||||
def test_get_mpi(device): | |||||
""" | |||||
测试 jittor 多卡的初始化情况 | |||||
""" | |||||
model = JittorNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(NotImplementedError): | |||||
driver = initialize_jittor_driver("jittor", device, model) | |||||
# assert isinstance(driver, JittorMPIDriver) |
@@ -1,99 +1,614 @@ | |||||
import pytest | import pytest | ||||
import os | |||||
from copy import deepcopy | |||||
from pathlib import Path | |||||
import numpy as np | |||||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.drivers.jittor_driver import JittorSingleDriver | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.core.dataloaders import JittorDataLoader | |||||
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||||
from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor as jt # 将 jittor 引入 | |||||
from jittor import nn, Module # 引入相关的模块 | |||||
from jittor import init | |||||
from jittor.dataset import MNIST | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
import jittor as jt | |||||
from jittor.dataset import ( | |||||
BatchSampler as JittorBatchSampler, | |||||
RandomSampler as JittorRandomSampler, | |||||
SequentialSampler as JittorSequentialSampler, | |||||
SubsetRandomSampler as JittorSubsetRandomSampler | |||||
) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False): | |||||
""" | |||||
:param dataset: | |||||
:param use_dataloader: 是否使用 JittorDataLoader 包裹 | |||||
:param sampler: 使用 BatchSampler Samlper 还是不使用 Sampler | |||||
""" | |||||
if use_dataloader: | |||||
dataloader = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) | |||||
dataloader.dataset.set_attrs(sampler=sampler) | |||||
else: | |||||
dataloader = dataset | |||||
dataloader.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler) | |||||
class Model(Module): | |||||
def __init__ (self): | |||||
super (Model, self).__init__() | |||||
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | |||||
self.conv2 = nn.Conv (32, 64, 3, 1) | |||||
self.bn = nn.BatchNorm(64) | |||||
self.max_pool = nn.Pool (2, 2) | |||||
self.relu = nn.Relu() | |||||
self.fc1 = nn.Linear (64 * 12 * 12, 256) | |||||
self.fc2 = nn.Linear (256, 10) | |||||
def execute(self, x) : | |||||
# it's simliar to forward function in Pytorch | |||||
x = self.conv1 (x) | |||||
x = self.relu (x) | |||||
x = self.conv2 (x) | |||||
x = self.bn (x) | |||||
x = self.relu (x) | |||||
x = self.max_pool (x) | |||||
x = jt.reshape (x, [x.shape[0], -1]) | |||||
x = self.fc1 (x) | |||||
x = self.relu(x) | |||||
x = self.fc2 (x) | |||||
return x | |||||
return dataloader | |||||
############################################################################ | |||||
# | |||||
# 测试基类 JittorDrvier 中的一些简单函数 | |||||
# | |||||
############################################################################ | |||||
class TestJittorDriverFunctions: | |||||
""" | |||||
使用 JittorSingleDriver 测试基类的函数 | |||||
""" | |||||
@classmethod | |||||
def setup_class(self): | |||||
model = JittorNormalModel_Classification_1(10, 32) | |||||
self.driver = JittorSingleDriver(model, device="cpu") | |||||
@pytest.mark.jittor | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
测试对合法的 optimizers 的检查 | |||||
""" | |||||
# 单个 optimizer | |||||
optimizer = jt.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
# optimizer 列表 | |||||
optimizers = [ | |||||
jt.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
@pytest.mark.torchjittor | |||||
def test_invalid_optimizers(self): | |||||
""" | |||||
测试传入非法的 optimizers | |||||
""" | |||||
# 单个 optimizer | |||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizers = [ | |||||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
] | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizers) | |||||
@pytest.mark.jittor | |||||
def test_check_dataloader_legality(self): | |||||
""" | |||||
测试 check_dataloader_legality 函数的表现 | |||||
""" | |||||
# 使用 JittorDataLoader | |||||
dataloader = JittorDataLoader(JittorNormalDataset()) | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
# 使用 jittor.dataset.Dataset | |||||
self.driver.check_dataloader_legality(JittorNormalDataset()) | |||||
@pytest.mark.torchjittor | |||||
def test_check_dataloader_legality_invalid(self): | |||||
""" | |||||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
""" | |||||
# 创建 torch 的 dataloader | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(TypeError): | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.jittor | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
测试 tensor_to_numeric 函数 | |||||
""" | |||||
# 单个张量 | |||||
tensor = jt.Var(3) | |||||
res = JittorSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == 3 | |||||
tensor = jt.rand(3, 4) | |||||
res = JittorSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == tensor.tolist() | |||||
# 张量list | |||||
tensor_list = [jt.rand(6, 4, 2) for i in range(10)] | |||||
res = JittorSingleDriver.tensor_to_numeric(tensor_list) | |||||
assert isinstance(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
assert res == tensor_list | |||||
# 张量tuple | |||||
tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)]) | |||||
res = JittorSingleDriver.tensor_to_numeric(tensor_tuple) | |||||
assert isinstance(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
assert res == tensor_tuple | |||||
# 张量dict | |||||
tensor_dict = { | |||||
"tensor": jt.rand(3, 4), | |||||
"list": [jt.rand(6, 4, 2) for i in range(10)], | |||||
"dict":{ | |||||
"list": [jt.rand(6, 4, 2) for i in range(10)], | |||||
"tensor": jt.rand(3, 4) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = JittorSingleDriver.tensor_to_numeric(tensor_dict) | |||||
assert isinstance(res, dict) | |||||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||||
assert isinstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
assert r == d.tolist() | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
assert r == d.tolist() | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||||
@pytest.mark.jittor | |||||
def test_tensor_to_numeric_reduce(self): | |||||
tensor = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||||
res_max = JittorSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||||
res_min = JittorSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||||
res_sum = JittorSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||||
res_mean = JittorSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||||
assert res_max == 6 | |||||
assert res_min == 1 | |||||
assert res_sum == 21 | |||||
assert res_mean == 3.5 | |||||
@pytest.mark.jittor | |||||
def test_set_model_mode(self): | |||||
""" | |||||
测试 set_model_mode 函数 | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
assert self.driver.model.is_training() | |||||
self.driver.set_model_mode("eval") | |||||
assert not self.driver.model.is_training() | |||||
# 应该报错 | |||||
with pytest.raises(AssertionError): | |||||
self.driver.set_model_mode("test") | |||||
@pytest.mark.jittor | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
测试 move_model_to_device 函数,仅测试能否运行 | |||||
""" | |||||
JittorSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
@pytest.mark.jittor | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
测试 move_model_to_device 函数,仅测试能否运行 | |||||
""" | |||||
JittorSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||||
@pytest.mark.jittor | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
测试 set_deterministic_dataloader,仅测试能否运行 | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = JittorDataLoader(JittorNormalDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
self.driver.set_deterministic_dataloader(JittorNormalDataset()) | |||||
@pytest.mark.jittor | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
测试 set_sampler_epoch | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = JittorDataLoader(JittorNormalDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
self.driver.set_sampler_epoch(JittorNormalDataset(), 0) | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader): | |||||
""" | |||||
测试正常情况下 get_dataloader_args 的表现 | |||||
""" | |||||
dataloader = get_dataloader( | |||||
JittorNormalDataset(), | |||||
use_dataloader=use_dataloader, | |||||
sampler=None, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last | |||||
) | |||||
res = JittorSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, JittorNormalDataset) | |||||
assert res.sampler is None | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader): | |||||
""" | |||||
测试替换了 sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = JittorNormalDataset() | |||||
dataloader = get_dataloader( | |||||
dataset, | |||||
use_dataloader=use_dataloader, | |||||
batch_size=batch_size, | |||||
sampler=RandomSampler(dataset, shuffle=shuffle), | |||||
shuffle=shuffle, | |||||
drop_last=drop_last | |||||
) | |||||
res = JittorSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, JittorNormalDataset) | |||||
assert isinstance(res.sampler, RandomSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
############################################################################ | |||||
# | |||||
# 测试 JittorSingleDrvier 中的一些简单函数 | |||||
# | |||||
############################################################################ | |||||
@pytest.mark.jittor | @pytest.mark.jittor | ||||
@pytest.mark.skip("Skip jittor tests now.") | |||||
class TestSingleDevice: | |||||
def test_on_gpu_without_fp16(self): | |||||
# TODO get_dataloader | |||||
batch_size = 64 | |||||
learning_rate = 0.1 | |||||
epochs = 5 | |||||
losses = [] | |||||
losses_idx = [] | |||||
train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True) | |||||
val_loader = MNIST(train=False, batch_size=1, shuffle=False) | |||||
model = Model() | |||||
driver = JittorSingleDriver(model, device=[1]) | |||||
optimizer = nn.SGD(model.parameters(), learning_rate) | |||||
driver.set_optimizers(optimizer) | |||||
for epoch in range(epochs): | |||||
driver.set_model_mode("train") | |||||
lens = len(train_loader) | |||||
for batch_idx, (inputs, targets) in enumerate(train_loader): | |||||
outputs =driver.train_step(inputs) | |||||
loss = nn.cross_entropy_loss(outputs, targets) | |||||
driver.backward(loss) | |||||
driver.step() | |||||
driver.zero_grad() | |||||
losses.append(loss.data[0]) | |||||
losses_idx.append(epoch * lens + batch_idx) | |||||
test_loss = 0 | |||||
correct = 0 | |||||
total_acc = 0 | |||||
total_num = 0 | |||||
driver.set_model_mode("eval") | |||||
for batch_idx, (inputs, targets) in enumerate(val_loader): | |||||
batch_size = inputs.shape[0] | |||||
outputs = driver.test_step(inputs) | |||||
pred = np.argmax(outputs.data, axis=1) | |||||
acc = np.sum(targets.data==pred) | |||||
total_acc += acc | |||||
total_num += batch_size | |||||
acc = acc / batch_size | |||||
assert total_acc / total_num > 0.95 | |||||
def test_on_cpu_without_fp16(self): | |||||
pass | |||||
def test_on_gpu_with_fp16(self): | |||||
pass | |||||
class TestSingleDeviceFunction: | |||||
""" | |||||
测试其它函数的测试例 | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
model = JittorNormalModel_Classification_1(10, 784) | |||||
cls.driver = JittorSingleDriver(model, device="cpu") | |||||
def test_unwrap_model(self): | |||||
""" | |||||
测试能否运行 | |||||
""" | |||||
res = self.driver.unwrap_model() | |||||
assert res is self.driver.model | |||||
def test_is_distributed(self): | |||||
assert self.driver.is_distributed() == False | |||||
def test_move_data_to_device(self): | |||||
self.driver.move_data_to_device(jt.rand(32, 64)) | |||||
############################################################################ | |||||
# | |||||
# 测试 set_dist_repro_dataloader 函数 | |||||
# | |||||
############################################################################ | |||||
@pytest.mark.jittor | |||||
class TestSetDistReproDataloader: | |||||
""" | |||||
专门测试 set_dist_repro_dataloader 函数的类 | |||||
""" | |||||
def setup_method(self): | |||||
self.dataset = JittorNormalDataset(20) | |||||
model = JittorNormalModel_Classification_1(10, 32) | |||||
self.driver = JittorSingleDriver(model, device="cpu") | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_with_reproducible_false(self, use_dataloader): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||||
当dist为字符串时,此时应该返回原来的 dataloader | |||||
""" | |||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert replaced_loader is dataloader | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("sampler", [None, "random", "sequential"]) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_with_reproducible_true(self, shuffle, sampler, use_dataloader): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||||
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler | |||||
""" | |||||
if sampler == "random": | |||||
sampler = JittorRandomSampler(self.dataset) | |||||
_shuffle = True | |||||
elif sampler == "sequential": | |||||
sampler = JittorSequentialSampler(self.dataset) | |||||
_shuffle = False | |||||
else: | |||||
_shuffle = shuffle | |||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||||
assert replaced_loader.sampler.shuffle == _shuffle | |||||
assert replaced_loader.batch_size == dataloader.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_with_dist_batch_sampler(self, shuffle, use_dataloader): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||||
jittor 暂时不支持这种情况,会报错 | |||||
""" | |||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) | |||||
dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False) | |||||
with pytest.raises(RuntimeError): | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_with_dist_sampler(self, shuffle, use_dataloader): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||||
应该返回新的 dataloader,并将 sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) | |||||
dist = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||||
assert replaced_loader.sampler is dist | |||||
assert replaced_loader.batch_size == dataloader.batch_size | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = get_dataloader( | |||||
self.dataset, | |||||
use_dataloader=use_dataloader, | |||||
sampler=ReproduceBatchSampler( | |||||
JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), | |||||
batch_size=4, | |||||
drop_last=False, | |||||
), | |||||
batch_size=4, | |||||
shuffle=shuffle, | |||||
) | |||||
with pytest.raises(RuntimeError): | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = get_dataloader( | |||||
self.dataset, | |||||
use_dataloader=use_dataloader, | |||||
sampler=RandomSampler(self.dataset, shuffle), | |||||
batch_size=2, | |||||
shuffle=shuffle, | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.sampler is dataloader.sampler) | |||||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||||
assert replaced_loader.batch_size == 2 | |||||
assert replaced_loader.shuffle == shuffle | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader): | |||||
""" | |||||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||||
""" | |||||
# 迭代两个 batch | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
replaced_loader.sampler.set_epoch(6) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch.tolist()) | |||||
sampler_states = replaced_loader.sampler.state_dict() | |||||
# 重新加载,应该可以输出剩下的内容,且对于 JittorNormalDataset 来说,排序后应该是一个 range | |||||
left_idxes = set() | |||||
batch_size = replaced_loader.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# 重新构造 dataloader | |||||
if use_dataloader: | |||||
dataset = deepcopy(replaced_loader.dataset.dataset) | |||||
else: | |||||
dataset = deepcopy(replaced_loader) | |||||
new_loader = get_dataloader( | |||||
dataset=dataset, | |||||
use_dataloader=use_dataloader, | |||||
sampler = RandomSampler(dataset, shuffle=shuffle), | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False | |||||
) | |||||
new_loader.sampler.load_state_dict(sampler_states) | |||||
new_loader.sampler.set_epoch(6) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch.tolist()) | |||||
print(already_seen_idx) | |||||
print(left_idxes) | |||||
assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len | |||||
assert len(left_idxes | already_seen_idx) == self.dataset.total_len | |||||
############################################################################ | |||||
# | |||||
# 测试 save 和 load 相关的功能 | |||||
# | |||||
############################################################################ | |||||
def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01): | |||||
""" | |||||
生成driver | |||||
""" | |||||
model = JittorNormalModel_Classification_1(labels, features) | |||||
opt = jt.optim.Adam(params=model.parameters(), lr=lr) | |||||
driver = JittorSingleDriver(model, device=device, fp16=fp16) | |||||
driver.set_optimizers(opt) | |||||
driver.setup() | |||||
return driver | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_save_and_load_model(only_state_dict, use_dataloader): | |||||
""" | |||||
测试 save_model 和 load_model 函数 | |||||
""" | |||||
try: | |||||
path = "model" | |||||
dataset = JittorNormalXYDataset(20) | |||||
dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True) | |||||
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||||
driver1.save_model(path, only_state_dict) | |||||
driver2.load_model(path, only_state_dict) | |||||
for batch in dataloader: | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert jt.all_(jt.equal(res1["pred"], res2["pred"])) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \ | |||||
generate_random_driver(20, 1, device="gpu", lr=0.001) | |||||
dataset = JittorNormalXYDataset(20) | |||||
dataloader = get_dataloader( | |||||
dataset, use_dataloader, | |||||
sampler = RandomSampler(dataset, True), | |||||
batch_size=4, | |||||
shuffle=True | |||||
) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
driver1.set_sampler_epoch(dataloader, 7) | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
sampler_states = dataloader.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | |||||
# 更改 batch_size | |||||
dataloader = get_dataloader( | |||||
dataset, use_dataloader, | |||||
sampler=RandomSampler(dataset, True), | |||||
batch_size=2, | |||||
shuffle=True | |||||
) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. 检查 optimizer 的状态 | |||||
assert driver2.optimizers[0].lr == driver1.optimizers[0].lr | |||||
# 2. 检查 sampler 是否被正确地加载和替换 | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||||
assert replaced_loader.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches | |||||
assert replaced_loader.sampler.dataset.total_len == sampler_states["length"] | |||||
assert replaced_loader.sampler.shuffle == sampler_states["shuffle"] | |||||
# 4. 检查 model 的参数是否正确 | |||||
# 5. 检查 batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
driver2.set_sampler_epoch(replaced_loader, 7) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert jt.all_(jt.equal(res1["pred"], res2["pred"])) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len | |||||
assert len(left_x_batches | already_seen_x_set) == dataset.total_len | |||||
assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len | |||||
assert len(left_y_batches | already_seen_y_set) == dataset.total_len | |||||
finally: | |||||
rank_zero_rm(path) |
@@ -0,0 +1,43 @@ | |||||
import pytest | |||||
from fastNLP.core.drivers.jittor_driver.utils import replace_sampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.core.dataloaders import JittorDataLoader | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor as jt | |||||
from tests.helpers.datasets.jittor_data import JittorNormalDataset | |||||
@pytest.mark.jittor | |||||
@pytest.mark.parametrize("dataset", [ | |||||
JittorNormalDataset(20, batch_size=10, shuffle=True), | |||||
JittorNormalDataset(20, batch_size=5, drop_last=True), | |||||
JittorNormalDataset(20) | |||||
]) | |||||
def test_replace_sampler_dataset(dataset): | |||||
dataset = JittorNormalDataset(20) | |||||
sampler = RandomSampler(dataset) | |||||
replaced_loader = replace_sampler(dataset, sampler) | |||||
assert not (replaced_loader is dataset) | |||||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||||
assert replaced_loader.batch_size == dataset.batch_size | |||||
assert replaced_loader.drop_last == dataset.drop_last | |||||
assert replaced_loader.shuffle == dataset.shuffle | |||||
assert replaced_loader.total_len == dataset.total_len | |||||
@pytest.mark.jittor | |||||
def test_replace_sampler_jittordataloader(): | |||||
dataset = JittorNormalDataset(20, batch_size=10, shuffle=True) | |||||
dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True) | |||||
sampler = RandomSampler(dataset) | |||||
replaced_loader = replace_sampler(dataloader, sampler) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset) | |||||
assert isinstance(replaced_loader.sampler, RandomSampler) | |||||
assert replaced_loader.batch_size == 8 | |||||
assert replaced_loader.shuffle == True |
@@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||||
UnrepeatedSequentialSampler, | UnrepeatedSequentialSampler, | ||||
) | ) | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
@@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle.distributed as dist | import paddle.distributed as dist | ||||
from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) | |||||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
paddle_model = PaddleNormalModel_Classification_1(labels, features) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | ||||
driver = PaddleFleetDriver( | driver = PaddleFleetDriver( | ||||
model=paddle_model, | model=paddle_model, | ||||
@@ -465,10 +465,14 @@ class TestSetDistReproDataloader: | |||||
num_replicas = len(self.device) | num_replicas = len(self.device) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_idx = set() | already_seen_idx = set() | ||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.set_epoch(10) | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | |||||
already_seen_idx.update(batch.tolist()) | |||||
dist.barrier() | dist.barrier() | ||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | ||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
@@ -496,6 +500,7 @@ class TestSetDistReproDataloader: | |||||
pad=True | pad=True | ||||
) | ) | ||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.set_epoch(10) | |||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | ||||
@@ -508,8 +513,9 @@ class TestSetDistReproDataloader: | |||||
) | ) | ||||
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | ||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.sampler.set_epoch(10) | |||||
for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
left_idxes.update(batch) | |||||
left_idxes.update(batch.tolist()) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | ||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | ||||
@@ -533,7 +539,7 @@ class TestSaveLoad: | |||||
cls.driver = generate_driver(10, 10, device=[0,1]) | cls.driver = generate_driver(10, 10, device=[0,1]) | ||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = PaddleRandomMaxDataset(20, 10) | |||||
self.dataset = PaddleNormalXYDataset(40) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@@ -545,12 +551,12 @@ class TestSaveLoad: | |||||
path = "model" | path = "model" | ||||
dataloader = DataLoader(self.dataset, batch_size=2) | dataloader = DataLoader(self.dataset, batch_size=2) | ||||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1) | |||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save_model(path, only_state_dict) | self.driver1.save_model(path, only_state_dict) | ||||
else: | else: | ||||
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) | |||||
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))]) | |||||
# 同步 | # 同步 | ||||
dist.barrier() | dist.barrier() | ||||
@@ -594,8 +600,8 @@ class TestSaveLoad: | |||||
path = "model.ckp" | path = "model.ckp" | ||||
num_replicas = len(device) | num_replicas = len(device) | ||||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
generate_driver(10, 10, device=device, fp16=False) | |||||
self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \ | |||||
generate_driver(40, 1, device=device, fp16=False) | |||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=self.dataset, | dataset=self.dataset, | ||||
batch_sampler=BucketedBatchSampler( | batch_sampler=BucketedBatchSampler( | ||||
@@ -613,11 +619,12 @@ class TestSaveLoad: | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
self.driver1.set_sampler_epoch(dataloader, 2) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
# 同步 | # 同步 | ||||
dist.barrier() | dist.barrier() | ||||
@@ -669,10 +676,11 @@ class TestSaveLoad: | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
self.driver2.set_sampler_epoch(replaced_loader, 2) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
res1 = self.driver1.model( | res1 = self.driver1.model( | ||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | ||||
@@ -709,8 +717,8 @@ class TestSaveLoad: | |||||
num_replicas = len(device) | num_replicas = len(device) | ||||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
self.driver1 = generate_driver(40, 1, device=device, fp16=fp16) | |||||
self.driver2 = generate_driver(40, 1, device=device, fp16=False) | |||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) | batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) | ||||
batch_sampler.sampler = RandomSampler(self.dataset, True) | batch_sampler.sampler = RandomSampler(self.dataset, True) | ||||
batch_sampler.sampler.set_distributed( | batch_sampler.sampler.set_distributed( | ||||
@@ -726,11 +734,12 @@ class TestSaveLoad: | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
self.driver1.set_sampler_epoch(dataloader, 2) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
# 同步 | # 同步 | ||||
dist.barrier() | dist.barrier() | ||||
@@ -779,10 +788,11 @@ class TestSaveLoad: | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
self.driver2.set_sampler_epoch(replaced_loader, 2) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
res1 = self.driver1.model( | res1 = self.driver1.model( | ||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | ||||
@@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE: | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
model = PaddleNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver("torch", 0, model) | driver = initialize_paddle_driver("torch", 0, model) | ||||
@@ -26,7 +26,7 @@ def test_get_single_device(device): | |||||
测试正常情况下初始化 PaddleSingleDriver 的情况 | 测试正常情况下初始化 PaddleSingleDriver 的情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
model = PaddleNormalModel_Classification_1(20, 10) | |||||
driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) | ||||
assert isinstance(driver, PaddleSingleDriver) | assert isinstance(driver, PaddleSingleDriver) | ||||
@@ -41,7 +41,7 @@ def test_get_fleet(device): | |||||
测试 fleet 多卡的初始化情况 | 测试 fleet 多卡的初始化情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
model = PaddleNormalModel_Classification_1(20, 10) | |||||
driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) | ||||
assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
@@ -56,6 +56,6 @@ def test_device_out_of_range(device): | |||||
""" | """ | ||||
测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
model = PaddleNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) |
@@ -4,14 +4,16 @@ from pathlib import Path | |||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -31,102 +33,70 @@ class TestPaddleDriverFunctions: | |||||
model = PaddleNormalModel_Classification_1(10, 32) | model = PaddleNormalModel_Classification_1(10, 32) | ||||
self.driver = PaddleSingleDriver(model, device="cpu") | self.driver = PaddleSingleDriver(model, device="cpu") | ||||
@pytest.mark.torchpaddle | |||||
def test_check_single_optimizer_legality(self): | |||||
@pytest.mark.paddle | |||||
def test_check_optimizers_legality(self): | |||||
""" | """ | ||||
测试传入单个 optimizer 时的表现 | |||||
测试对合法的 optimizers 的检查 | |||||
""" | """ | ||||
# 单个 optimizer | |||||
optimizer = paddle.optimizer.Adam( | optimizer = paddle.optimizer.Adam( | ||||
parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
learning_rate=0.01 | learning_rate=0.01 | ||||
) | ) | ||||
self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
# 传入torch的optimizer时,应该报错ValueError | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizer) | |||||
@pytest.mark.torchpaddle | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
测试传入 optimizer list 的表现 | |||||
""" | |||||
# optimizer 列表 | |||||
optimizers = [ | optimizers = [ | ||||
paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
learning_rate=0.01 | learning_rate=0.01 | ||||
) for i in range(10) | ) for i in range(10) | ||||
] | ] | ||||
self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
optimizers += [ | |||||
@pytest.mark.torchpaddle | |||||
def test_invalid_optimizers(self): | |||||
""" | |||||
测试传入非法的 optimizers | |||||
""" | |||||
# 单个 optimizer | |||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizers = [ | |||||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | ||||
] | ] | ||||
with pytest.raises(ValueError): | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_dataloader_legality_in_train(self): | |||||
@pytest.mark.paddle | |||||
def test_check_dataloader_legality(self): | |||||
""" | """ | ||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
测试 check_dataloader_legality 函数的表现 | |||||
""" | """ | ||||
dataloader = DataLoader(PaddleNormalDataset()) | dataloader = DataLoader(PaddleNormalDataset()) | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建torch的dataloader | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.torchpaddle | @pytest.mark.torchpaddle | ||||
def test_check_dataloader_legality_in_test(self): | |||||
def test_check_dataloader_legality_invalid(self): | |||||
""" | """ | ||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
""" | """ | ||||
# 此时传入的应该是dict | |||||
dataloader = { | |||||
"train": DataLoader(PaddleNormalDataset()), | |||||
"test":DataLoader(PaddleNormalDataset()) | |||||
} | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | |||||
dataloader = { | |||||
"train": DataLoader(PaddleNormalDataset()), | |||||
"test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
} | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 传入的不是 dict ,应该报错 | |||||
dataloader = DataLoader(PaddleNormalDataset()) | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建 torch 的 dataloader | # 创建 torch 的 dataloader | ||||
train_loader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = torch.utils.data.DataLoader( | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | TorchNormalDataset(), | ||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
) | ) | ||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
with pytest.raises(TypeError): | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
@@ -505,10 +475,14 @@ class TestSetDistReproDataloader: | |||||
# 迭代两个 batch | # 迭代两个 batch | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_idx = set() | already_seen_idx = set() | ||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.set_epoch(5) | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | |||||
already_seen_idx.update(batch.tolist()) | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | ||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
@@ -529,6 +503,7 @@ class TestSetDistReproDataloader: | |||||
) | ) | ||||
) | ) | ||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.set_epoch(5) | |||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
@@ -537,8 +512,9 @@ class TestSetDistReproDataloader: | |||||
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) | batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) | ||||
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | ||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.sampler.set_epoch(5) | |||||
for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
left_idxes.update(batch) | |||||
left_idxes.update(batch.tolist()) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | ||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | assert len(left_idxes | already_seen_idx) == len(self.dataset) | ||||
@@ -549,7 +525,7 @@ class TestSetDistReproDataloader: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||||
""" | """ | ||||
生成driver | 生成driver | ||||
""" | """ | ||||
@@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict): | |||||
""" | """ | ||||
try: | try: | ||||
path = "model" | path = "model" | ||||
dataset = PaddleRandomMaxDataset(40, 10) | |||||
dataset = PaddleNormalXYDataset(20) | |||||
dataloader = DataLoader(dataset, batch_size=4) | dataloader = DataLoader(dataset, batch_size=4) | ||||
driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") | |||||
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
@@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict): | |||||
driver2.load_model(path, only_state_dict) | driver2.load_model(path, only_state_dict) | ||||
for batch in dataloader: | for batch in dataloader: | ||||
print("?") | |||||
batch = driver1.move_data_to_device(batch) | batch = driver1.move_data_to_device(batch) | ||||
res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
@@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
try: | try: | ||||
path = "model.ckp" | path = "model.ckp" | ||||
dataset = PaddleRandomMaxDataset(40, 10) | |||||
dataset = PaddleNormalXYDataset(40) | |||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | ||||
) | ) | ||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
driver1.set_sampler_epoch(dataloader, 3) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
@@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
driver2.set_sampler_epoch(replaced_loader, 3) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
@@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | def test_save_and_load_with_randomsampler(only_state_dict, fp16): | ||||
""" | """ | ||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||||
""" | """ | ||||
try: | try: | ||||
path = "model.ckp" | path = "model.ckp" | ||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||||
dataset = PaddleRandomMaxDataset(40, 10) | |||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||||
dataset = PaddleNormalXYDataset(40) | |||||
batch_sampler = BatchSampler(dataset=dataset, batch_size=4) | batch_sampler = BatchSampler(dataset=dataset, batch_size=4) | ||||
batch_sampler.sampler = RandomSampler(dataset, True) | batch_sampler.sampler = RandomSampler(dataset, True) | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
@@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
driver1.set_sampler_epoch(dataloader, 3) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
@@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
driver1.set_sampler_epoch(replaced_loader, 3) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
@@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||||
UnrepeatedSequentialSampler, | UnrepeatedSequentialSampler, | ||||
) | ) | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
@@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH: | |||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torch.utils.data import DataLoader, BatchSampler | from torch.utils.data import DataLoader, BatchSampler | ||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | |||||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
torch_model = TorchNormalModel_Classification_1(labels, features) | |||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | ||||
device = [torch.device(i) for i in device] | device = [torch.device(i) for i in device] | ||||
driver = TorchDDPDriver( | driver = TorchDDPDriver( | ||||
@@ -504,10 +504,14 @@ class TestSetDistReproDataloader: | |||||
num_replicas = len(self.device) | num_replicas = len(self.device) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_idx = set() | already_seen_idx = set() | ||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.set_epoch(4) | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | |||||
already_seen_idx.update(batch.tolist()) | |||||
dist.barrier() | dist.barrier() | ||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | ||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
@@ -533,6 +537,7 @@ class TestSetDistReproDataloader: | |||||
pad=True | pad=True | ||||
) | ) | ||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.set_epoch(4) | |||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | ||||
@@ -543,8 +548,9 @@ class TestSetDistReproDataloader: | |||||
rank=driver.global_rank | rank=driver.global_rank | ||||
) | ) | ||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.sampler.set_epoch(4) | |||||
for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
left_idxes.update(batch) | |||||
left_idxes.update(batch.tolist()) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | ||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | ||||
@@ -562,7 +568,7 @@ class TestSaveLoad: | |||||
""" | """ | ||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchArgMaxDataset(10, 20) | |||||
self.dataset = TorchNormalXYDataset(20) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@@ -574,7 +580,7 @@ class TestSaveLoad: | |||||
path = "model" | path = "model" | ||||
dataloader = DataLoader(self.dataset, batch_size=2) | dataloader = DataLoader(self.dataset, batch_size=2) | ||||
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) | |||||
driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
@@ -618,8 +624,8 @@ class TestSaveLoad: | |||||
path = "model.ckp" | path = "model.ckp" | ||||
num_replicas = len(device) | num_replicas = len(device) | ||||
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
generate_driver(10, 10, device=device, fp16=False) | |||||
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ | |||||
generate_driver(20, 1, device=device, fp16=False) | |||||
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
self.dataset, | self.dataset, | ||||
length=[10 for i in range(len(self.dataset))], | length=[10 for i in range(len(self.dataset))], | ||||
@@ -636,11 +642,12 @@ class TestSaveLoad: | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
driver1.set_sampler_epoch(dataloader, 4) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
# 同步 | # 同步 | ||||
dist.barrier() | dist.barrier() | ||||
@@ -665,7 +672,6 @@ class TestSaveLoad: | |||||
pad=True | pad=True | ||||
) | ) | ||||
dist.barrier() | dist.barrier() | ||||
print("========load=======", driver1.global_rank, driver2.global_rank) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | ||||
dist.barrier() | dist.barrier() | ||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
@@ -690,10 +696,11 @@ class TestSaveLoad: | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
driver2.set_sampler_epoch(replaced_loader, 4) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model( | res1 = driver1.model( | ||||
batch, | batch, | ||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | fastnlp_fn=driver1.model.module.model.evaluate_step, | ||||
@@ -716,7 +723,6 @@ class TestSaveLoad: | |||||
dist.barrier() | dist.barrier() | ||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
print("=======delete======") | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -735,8 +741,8 @@ class TestSaveLoad: | |||||
num_replicas = len(device) | num_replicas = len(device) | ||||
driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
driver1 = generate_driver(20, 1, device=device, fp16=fp16) | |||||
driver2 = generate_driver(20, 1, device=device, fp16=False) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
@@ -748,11 +754,12 @@ class TestSaveLoad: | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
driver1.set_sampler_epoch(dataloader, 4) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
# 同步 | # 同步 | ||||
dist.barrier() | dist.barrier() | ||||
@@ -797,10 +804,11 @@ class TestSaveLoad: | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
driver2.set_sampler_epoch(replaced_loader, 4) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model( | res1 = driver1.model( | ||||
batch, | batch, | ||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | fastnlp_fn=driver1.model.module.model.evaluate_step, | ||||
@@ -14,7 +14,7 @@ else: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
model = TorchNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_torch_driver("paddle", 0, model) | driver = initialize_torch_driver("paddle", 0, model) | ||||
@@ -33,7 +33,7 @@ def test_get_single_device(driver, device): | |||||
测试正常情况下初始化TorchSingleDriver的情况 | 测试正常情况下初始化TorchSingleDriver的情况 | ||||
""" | """ | ||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
model = TorchNormalModel_Classification_1(20, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
assert isinstance(driver, TorchSingleDriver) | assert isinstance(driver, TorchSingleDriver) | ||||
@@ -52,7 +52,7 @@ def test_get_ddp(driver, device): | |||||
测试 ddp 多卡的初始化情况 | 测试 ddp 多卡的初始化情况 | ||||
""" | """ | ||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
model = TorchNormalModel_Classification_1(20, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
assert isinstance(driver, TorchDDPDriver) | assert isinstance(driver, TorchDDPDriver) | ||||
@@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device): | |||||
""" | """ | ||||
测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
""" | """ | ||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
model = TorchNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) |
@@ -6,7 +6,7 @@ from pkg_resources import parse_version | |||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | ||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
@@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch.utils.data import DataLoader, BatchSampler | from torch.utils.data import DataLoader, BatchSampler | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -67,95 +68,67 @@ class TestTorchDriverFunctions: | |||||
model = TorchNormalModel_Classification_1(10, 32) | model = TorchNormalModel_Classification_1(10, 32) | ||||
self.driver = TorchSingleDriver(model, device="cpu") | self.driver = TorchSingleDriver(model, device="cpu") | ||||
@pytest.mark.torchpaddle | |||||
def test_check_single_optimizer_legality(self): | |||||
@pytest.mark.torch | |||||
def test_check_optimizers_legality(self): | |||||
""" | """ | ||||
测试传入单个 optimizer 时的表现 | |||||
测试对合法 optimizers 的检查 | |||||
""" | """ | ||||
# 单个 optimizer | |||||
optimizer = torch.optim.Adam( | optimizer = torch.optim.Adam( | ||||
params=self.driver.model.parameters(), | params=self.driver.model.parameters(), | ||||
lr=0.01 | lr=0.01 | ||||
) | ) | ||||
self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
learning_rate=0.01, | |||||
) | |||||
# 传入 torch 的 optimize r时,应该报错 ValueError | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizer) | |||||
@pytest.mark.torchpaddle | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
测试传入 optimizer list 的表现 | |||||
""" | |||||
# 列表 | |||||
optimizers = [ | optimizers = [ | ||||
torch.optim.Adam( | torch.optim.Adam( | ||||
params=self.driver.model.parameters(), | params=self.driver.model.parameters(), | ||||
lr=0.01 | lr=0.01 | ||||
) for i in range(10) | ) for i in range(10) | ||||
] | ] | ||||
self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
optimizers += [ | |||||
@pytest.mark.torchpaddle | |||||
def test_invalid_optimizers(self): | |||||
""" | |||||
测试传入非法的 optimizers | |||||
""" | |||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
learning_rate=0.01, | |||||
) | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizers = [ | |||||
paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | ||||
learning_rate=0.01, | learning_rate=0.01, | ||||
) | ) | ||||
] | ] | ||||
with pytest.raises(ValueError): | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_dataloader_legality_in_train(self): | |||||
@pytest.mark.torch | |||||
def test_check_dataloader_legality(self): | |||||
""" | """ | ||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
测试 check_dataloader_legality 函数的表现 | |||||
""" | """ | ||||
dataloader = DataLoader(TorchNormalDataset()) | dataloader = DataLoader(TorchNormalDataset()) | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建 paddle 的 dataloader | |||||
dataloader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.torchpaddle | @pytest.mark.torchpaddle | ||||
def test_check_dataloader_legality_in_test(self): | |||||
def test_check_dataloader_legality_invalid(self): | |||||
""" | """ | ||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
""" | """ | ||||
# 此时传入的应该是dict | |||||
dataloader = { | |||||
"train": DataLoader(TorchNormalDataset()), | |||||
"test": DataLoader(TorchNormalDataset()) | |||||
} | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 传入的不是 dict,应该报错 | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建 paddle 的 dataloader | # 创建 paddle 的 dataloader | ||||
train_loader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = paddle.io.DataLoader( | |||||
dataloader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | PaddleNormalDataset(), | ||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
) | ) | ||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
with pytest.raises(TypeError): | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
@@ -515,10 +488,14 @@ class TestSetDistReproDataloader: | |||||
# 迭代两个 batch | # 迭代两个 batch | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_idx = set() | already_seen_idx = set() | ||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
replaced_loader.batch_sampler.set_epoch(3) | |||||
else: | |||||
replaced_loader.batch_sampler.sampler.set_epoch(3) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | |||||
already_seen_idx.update(batch.tolist()) | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | ||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
@@ -532,14 +509,16 @@ class TestSetDistReproDataloader: | |||||
# 重新改造 dataloader | # 重新改造 dataloader | ||||
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | ||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.set_epoch(3) | |||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# 重新构造 dataloader | # 重新构造 dataloader | ||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | ||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
new_loader.batch_sampler.sampler.set_epoch(3) | |||||
for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
left_idxes.update(batch) | |||||
left_idxes.update(batch.tolist()) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | ||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | assert len(left_idxes | already_seen_idx) == len(self.dataset) | ||||
@@ -550,7 +529,7 @@ class TestSetDistReproDataloader: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||||
""" | """ | ||||
生成driver | 生成driver | ||||
""" | """ | ||||
@@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict): | |||||
""" | """ | ||||
try: | try: | ||||
path = "model" | path = "model" | ||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataset = TorchNormalXYDataset(20) | |||||
dataloader = DataLoader(dataset, batch_size=4) | dataloader = DataLoader(dataset, batch_size=4) | ||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) | |||||
driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
driver2.load_model(path, only_state_dict) | driver2.load_model(path, only_state_dict) | ||||
@@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
try: | try: | ||||
path = "model.ckp" | path = "model.ckp" | ||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataset = TorchNormalXYDataset(20) | |||||
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | ||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") | |||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
driver1.set_sampler_epoch(dataloader, 3) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
@@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
driver1.set_sampler_epoch(replaced_loader, 3) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
batch = driver2.move_data_to_device(batch) | batch = driver2.move_data_to_device(batch) | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
assert torch.equal(res1["preds"], res2["preds"]) | assert torch.equal(res1["preds"], res2["preds"]) | ||||
@@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | def test_save_and_load_with_randomsampler(only_state_dict, fp16): | ||||
""" | """ | ||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||||
""" | """ | ||||
try: | try: | ||||
path = "model.ckp" | path = "model.ckp" | ||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") | |||||
dataset = TorchNormalXYDataset(40) | |||||
dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
driver1.set_sampler_epoch(dataloader, 3) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
@@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
left_x_batches = set() | left_x_batches = set() | ||||
left_y_batches = set() | left_y_batches = set() | ||||
# set epoch | |||||
driver2.set_sampler_epoch(replaced_loader, 3) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
batch = driver2.move_data_to_device(batch) | batch = driver2.move_data_to_device(batch) | ||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
assert torch.equal(res1["preds"], res2["preds"]) | assert torch.equal(res1["preds"], res2["preds"]) | ||||
@@ -0,0 +1,46 @@ | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor as jt | |||||
from jittor.dataset import Dataset | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
class JittorNormalDataset(Dataset): | |||||
def __init__(self, num_of_data=100, **kwargs): | |||||
super(JittorNormalDataset, self).__init__(**kwargs) | |||||
self._data = list(range(num_of_data)) | |||||
self.set_attrs(total_len=num_of_data) | |||||
def __getitem__(self, item): | |||||
return self._data[item] | |||||
class JittorNormalXYDataset(Dataset): | |||||
""" | |||||
可以被输入到分类模型中的普通数据集 | |||||
""" | |||||
def __init__(self, num_of_data=1000, **kwargs): | |||||
super(JittorNormalXYDataset, self).__init__(**kwargs) | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
self.set_attrs(total_len=num_of_data) | |||||
def __getitem__(self, item): | |||||
return { | |||||
"x": jt.Var([self._data[item]]), | |||||
"y": jt.Var([self._data[item]]) | |||||
} | |||||
class JittorArgMaxDataset(Dataset): | |||||
def __init__(self, num_samples, num_features, **kwargs): | |||||
super(JittorArgMaxDataset, self).__init__(**kwargs) | |||||
self.x = jt.randn(num_samples, num_features) | |||||
self.y = self.x.argmax(dim=-1) | |||||
self.set_attrs(total_len=num_samples) | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} | |||||
if __name__ == "__main__": | |||||
dataset = JittorNormalDataset() | |||||
print(len(dataset)) |
@@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset): | |||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return self._data[item] | return self._data[item] | ||||
class PaddleNormalXYDataset(Dataset): | |||||
""" | |||||
可以被输入到分类模型中的普通数据集 | |||||
""" | |||||
def __init__(self, num_of_data=1000): | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return { | |||||
"x": paddle.to_tensor([self._data[item]], dtype="float32"), | |||||
"y": paddle.to_tensor([self._data[item]], dtype="float32") | |||||
} | |||||
class PaddleRandomMaxDataset(Dataset): | |||||
class PaddleArgMaxDataset(Dataset): | |||||
def __init__(self, num_samples, num_features): | def __init__(self, num_samples, num_features): | ||||
self.x = paddle.randn((num_samples, num_features)) | self.x = paddle.randn((num_samples, num_features)) | ||||
self.y = self.x.argmax(axis=-1) | self.y = self.x.argmax(axis=-1) | ||||
@@ -1,4 +1,6 @@ | |||||
from functools import reduce | from functools import reduce | ||||
from numpy import dtype | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
@@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset): | |||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return self._data[item] | return self._data[item] | ||||
class TorchNormalXYDataset(Dataset): | |||||
""" | |||||
可以被输入到分类模型中的普通数据集 | |||||
""" | |||||
def __init__(self, num_of_data=1000): | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return { | |||||
"x": torch.tensor([self._data[item]], dtype=torch.float), | |||||
"y": torch.tensor([self._data[item]], dtype=torch.float) | |||||
} | |||||
# 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | ||||
class TorchNormalDataset_Classification(Dataset): | class TorchNormalDataset_Classification(Dataset): | ||||
@@ -0,0 +1,57 @@ | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
from jittor import Module, nn | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
class JittorNormalModel_Classification_1(Module): | |||||
""" | |||||
基础的 jittor 分类模型 | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(JittorNormalModel_Classification_1, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def execute(self, x): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return x | |||||
def train_step(self, x, y): | |||||
x = self(x) | |||||
return {"loss": self.loss_fn(x, y)} | |||||
def evaluate_step(self, x, y): | |||||
x = self(x) | |||||
return {"pred": x, "target": y.reshape((-1,))} | |||||
class JittorNormalModel_Classification_2(Module): | |||||
""" | |||||
基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景 | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(JittorNormalModel_Classification_2, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def execute(self, x, y): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} |
@@ -8,7 +8,7 @@ else: | |||||
class PaddleNormalModel_Classification_1(Layer): | class PaddleNormalModel_Classification_1(Layer): | ||||
""" | """ | ||||
基础的paddle分类模型 | |||||
基础的 paddle 分类模型 | |||||
""" | """ | ||||
def __init__(self, num_labels, feature_dimension): | def __init__(self, num_labels, feature_dimension): | ||||
super(PaddleNormalModel_Classification_1, self).__init__() | super(PaddleNormalModel_Classification_1, self).__init__() | ||||
@@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer): | |||||
class PaddleNormalModel_Classification_2(Layer): | class PaddleNormalModel_Classification_2(Layer): | ||||
""" | """ | ||||
基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||||
基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||||
""" | """ | ||||
def __init__(self, num_labels, feature_dimension): | def __init__(self, num_labels, feature_dimension): | ||||
super(PaddleNormalModel_Classification_2, self).__init__() | super(PaddleNormalModel_Classification_2, self).__init__() | ||||
@@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer): | |||||
x = self.ac1(self.linear1(x)) | x = self.ac1(self.linear1(x)) | ||||
x = self.ac2(self.linear2(x)) | x = self.ac2(self.linear2(x)) | ||||
x = self.output(x) | x = self.output(x) | ||||
loss = self.loss_fn(x, y) | |||||
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} | return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} |
@@ -33,7 +33,11 @@ class TestPaddle2Torch: | |||||
""" | """ | ||||
assert isinstance(tensor, torch.Tensor) | assert isinstance(tensor, torch.Tensor) | ||||
assert tensor.device == torch.device(device) | |||||
if device == "cpu": | |||||
assert not tensor.is_cuda | |||||
else: | |||||
assert tensor.is_cuda | |||||
assert tensor.device.index == torch.device(device).index | |||||
assert tensor.requires_grad == requires_grad | assert tensor.requires_grad == requires_grad | ||||
def test_gradient(self): | def test_gradient(self): | ||||
@@ -261,7 +265,8 @@ class TestJittor2Torch: | |||||
if device == "cpu": | if device == "cpu": | ||||
assert not tensor.is_cuda | assert not tensor.is_cuda | ||||
else: | else: | ||||
assert tensor.device == torch.device(device) | |||||
assert tensor.is_cuda | |||||
assert tensor.device.index == torch.device(device).index | |||||
assert tensor.requires_grad == requires_grad | assert tensor.requires_grad == requires_grad | ||||
def test_var_transfer(self): | def test_var_transfer(self): | ||||
@@ -271,7 +276,10 @@ class TestJittor2Torch: | |||||
jittor_var = jittor.rand((3, 4, 5)) | jittor_var = jittor.rand((3, 4, 5)) | ||||
res = jittor2torch(jittor_var) | res = jittor2torch(jittor_var) | ||||
self.check_torch_tensor(res, "cpu", True) | |||||
if jittor.flags.use_cuda: | |||||
self.check_torch_tensor(res, "cuda:0", True) | |||||
else: | |||||
self.check_torch_tensor(res, "cpu", True) | |||||
res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) | res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) | ||||
self.check_torch_tensor(res, "cuda:2", True) | self.check_torch_tensor(res, "cuda:2", True) | ||||
@@ -291,7 +299,10 @@ class TestJittor2Torch: | |||||
res = jittor2torch(jittor_list) | res = jittor2torch(jittor_list) | ||||
assert isinstance(res, list) | assert isinstance(res, list) | ||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cpu", True) | |||||
if jittor.flags.use_cuda: | |||||
self.check_torch_tensor(t, "cuda:0", True) | |||||
else: | |||||
self.check_torch_tensor(t, "cpu", True) | |||||
res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) | res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) | ||||
assert isinstance(res, list) | assert isinstance(res, list) | ||||
@@ -327,17 +338,29 @@ class TestJittor2Torch: | |||||
} | } | ||||
res = jittor2torch(jittor_dict) | res = jittor2torch(jittor_dict) | ||||
assert isinstance(res, dict) | assert isinstance(res, dict) | ||||
self.check_torch_tensor(res["tensor"], "cpu", True) | |||||
if jittor.flags.use_cuda: | |||||
self.check_torch_tensor(res["tensor"], "cuda:0", True) | |||||
else: | |||||
self.check_torch_tensor(res["tensor"], "cpu", True) | |||||
assert isinstance(res["list"], list) | assert isinstance(res["list"], list) | ||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_torch_tensor(t, "cpu", True) | |||||
if jittor.flags.use_cuda: | |||||
self.check_torch_tensor(t, "cuda:0", True) | |||||
else: | |||||
self.check_torch_tensor(t, "cpu", True) | |||||
assert isinstance(res["int"], int) | assert isinstance(res["int"], int) | ||||
assert isinstance(res["string"], str) | assert isinstance(res["string"], str) | ||||
assert isinstance(res["dict"], dict) | assert isinstance(res["dict"], dict) | ||||
assert isinstance(res["dict"]["list"], list) | assert isinstance(res["dict"]["list"], list) | ||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_torch_tensor(t, "cpu", True) | |||||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||||
if jittor.flags.use_cuda: | |||||
self.check_torch_tensor(t, "cuda:0", True) | |||||
else: | |||||
self.check_torch_tensor(t, "cpu", True) | |||||
if jittor.flags.use_cuda: | |||||
self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True) | |||||
else: | |||||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||||
############################################################################ | ############################################################################ | ||||