
Fix previously incorrect name references

tags/v1.0.0alpha
yh_cc 3 years ago
parent commit 388e426d78
9 changed files with 44 additions and 33 deletions
  1. +0 -1   fastNLP/core/controllers/trainer.py
  2. +2 -2   fastNLP/core/drivers/driver.py
  3. +3 -3   fastNLP/core/drivers/jittor_driver/single_device.py
  4. +5 -5   fastNLP/core/drivers/paddle_driver/single_device.py
  5. +6 -6   fastNLP/core/drivers/torch_driver/ddp.py
  6. +6 -6   fastNLP/core/drivers/torch_driver/single_device.py
  7. +8 -9   fastNLP/core/drivers/torch_driver/torch_driver.py
  8. +5 -1   fastNLP/core/samplers/__init__.py
  9. +9 -0   fastNLP/core/samplers/reproducible_batch_sampler.py

+0 -1   fastNLP/core/controllers/trainer.py

@@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.envs import rank_zero_call
-from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME



+2 -2   fastNLP/core/drivers/driver.py

@@ -49,13 +49,13 @@ class Driver(ABC):
duplicated across different GPUs; when it is 'unrepeatdist', it means the dataloader should guarantee that the data iterated on all GPUs, merged together, is exactly the original
data, allowing the number of batches on different GPUs to differ. When the trainer kwarg `use_dist_sampler` is True, this value is "dist";
otherwise it is None; when the evaluator kwarg `use_dist_sampler` is True, this value is "unrepeatdist", otherwise None;
-note that when dist is a ReproducibleIterator or RandomBatchSampler, it is the driver.load function that is calling this during checkpoint resumption;
+note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, it is the driver.load function that is calling this during checkpoint resumption;
when dist is a str or None, the trainer is calling this function during initialization;

:param reproducible: if False, no special handling is needed; if True, the returned dataloader must be able to save its current iteration state so that
it can be restored later.
:return: a new dataloader object with its sampler replaced should be returned (note that a new dataloader object must be returned here); in addition,
-if the dataloader passed in contains a ReproducibleSampler or RandomBatchSampler, a newly initialized one needs to be placed into the returned
+if the dataloader passed in contains a ReproducibleSampler or ReproducibleBatchSampler, a newly initialized one needs to be placed into the returned
dataloader. If dist is None and reproducible is False, the original object may be returned directly.
"""
if dist is None and reproducible is False:
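
For reference, the concrete object that driver.load passes as `dist` on the checkpoint-resumption path is typically a reproducible batch sampler wrapping the user's original batch_sampler. Below is a minimal sketch built only from the constructor arguments visible in this commit; the plain torch DataLoader and the numbers are illustrative assumptions, not part of the change.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP.core.samplers import RandomBatchSampler

    # Wrap an ordinary torch batch_sampler so its iteration state can be saved and restored.
    dataset = TensorDataset(torch.arange(20).float().unsqueeze(1))
    dataloader = DataLoader(dataset, batch_size=4, drop_last=False)

    dist = RandomBatchSampler(
        batch_sampler=dataloader.batch_sampler,  # the user's original batch_sampler
        batch_size=4,
        drop_last=False,
    )
    # On resumption, driver.load passes an object like `dist` to set_dist_repro_dataloader;
    # during initialization the trainer passes "dist", "unrepeatdist", or None instead.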


+3 -3   fastNLP/core/drivers/jittor_driver/single_device.py

@@ -3,7 +3,7 @@ from typing import Dict, Union
from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler

if _NEED_IMPORT_JITTOR:
import jittor
@@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver):
def is_distributed(self):
return False

-def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
reproducible: bool = False, sampler_or_batch_sampler=None):
# the reproducible-related functionality has not been implemented yet
-if isinstance(dist, RandomBatchSampler):
+if isinstance(dist, ReproducibleBatchSampler):
raise NotImplementedError
dataloader.batch_sampler = dist_sample
if isinstance(dist, ReproducibleSampler):


+5 -5   fastNLP/core/drivers/paddle_driver/single_device.py

@@ -10,7 +10,7 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -139,12 +139,12 @@ class PaddleSingleDriver(PaddleDriver):
"""
return paddle_move_data_to_device(batch, "gpu:0")

-def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
reproducible: bool = False, sampler_or_batch_sampler=None):
# IteratorDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
-if isinstance(dist, RandomBatchSampler):
+if isinstance(dist, ReproducibleBatchSampler):
dataloader.batch_sampler = dist
return dataloader
if isinstance(dist, ReproducibleSampler):
@@ -154,11 +154,11 @@ class PaddleSingleDriver(PaddleDriver):
if reproducible:
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
return dataloader
-elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
+elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
return dataloader
else:
# TODO
-batch_sampler = RandomBatchSampler(
+batch_sampler = ReproducibleBatchSampler(
batch_sampler=dataloader.batch_sampler,
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.drop_last


+6 -6   fastNLP/core/drivers/torch_driver/ddp.py

@@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
)
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger
@@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver):
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
return self._test_step(batch)

-def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
+def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
reproducible: bool = False):
-# if dist is a RandomBatchSampler or ReproducibleIterator, this means driver.load is calling this function during checkpoint resumption;
+# if dist is a ReproducibleBatchSampler or ReproducibleSampler, this means driver.load is calling this function during checkpoint resumption;
# note that dist_sampler.set_distributed does not need to be called here; if the user is using TorchDDPDriver, it has already been called when the Trainer was initialized;
-if isinstance(dist, RandomBatchSampler):
+if isinstance(dist, ReproducibleBatchSampler):
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
@@ -472,7 +472,7 @@ class TorchDDPDriver(TorchDriver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
-if isinstance(dist, RandomBatchSampler):
+if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
@@ -483,7 +483,7 @@ class TorchDDPDriver(TorchDriver):
elif dist == "dist":
args = self.get_dataloader_args(dataloader)
# if the user's trainer.use_dist_sampler is True, then whether or not checkpoint resumption is being performed does not affect the behavior here;
-if isinstance(args.batch_sampler, RandomBatchSampler):
+if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
batch_sampler.set_distributed(
num_replicas=self.world_size,


+6 -6   fastNLP/core/drivers/torch_driver/single_device.py

@@ -13,7 +13,7 @@ __all__ = [
from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.log import logger


@@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver):
else:
return self._test_step(batch)

-def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
+def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):

-# if dist is a RandomBatchSampler or ReproducibleIterator, this means driver.load is calling this function during checkpoint resumption;
-if isinstance(dist, RandomBatchSampler):
+# if dist is a ReproducibleBatchSampler or ReproducibleIterator, this means driver.load is calling this function during checkpoint resumption;
+if isinstance(dist, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleSampler):
return replace_sampler(dataloader, dist)

# if dist is a str or None, this means the trainer is calling this function during initialization;
args = self.get_dataloader_args(dataloader)
-if isinstance(args.batch_sampler, RandomBatchSampler):
+if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver):
return replace_sampler(dataloader, sampler)

if reproducible:
-batch_sampler = RandomBatchSampler(
+batch_sampler = ReproducibleBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+8 -9   fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler


class TorchDriver(Driver):
@@ -183,9 +183,9 @@ class TorchDriver(Driver):

# 1. the sampler's state; since we support resume training, i.e. restoring precisely to a specific batch;
# first, a pytorch DataLoader always has a sampler; on the other hand, during checkpoint resumption we always, inside `set_`, replace the dataloader's
-# sampler with a `ReproducibleSampler`; otherwise, in the single-card case, the batch_sampler is replaced with a `RandomBatchSampler`;
+# sampler with a `ReproducibleSampler`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
dataloader_args = self.get_dataloader_args(dataloader)
-if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
@@ -245,15 +245,14 @@ class TorchDriver(Driver):

# 3. restore the sampler's state;
dataloader_args = self.get_dataloader_args(dataloader)
-if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
-elif isinstance(dataloader_args.sampler, ReproducibleIterator):
+elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
-raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
-"`RandomBatchSampler` or `ReproducibleIterator`.")
+raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
else:
-sampler = RandomBatchSampler(
+sampler = ReproducibleBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -263,7 +262,7 @@ class TorchDriver(Driver):

# 4. adjust trainer_state.batch_idx_in_epoch
# here sampler is a RandomSampler-like sampler, not a batch_sampler;
-if not isinstance(sampler, RandomBatchSampler):
+if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size

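To make the batch_idx_in_epoch formula at the end of this hunk concrete, here is a small worked example with made-up numbers for the drop_last=True branch:

    # Illustrative values only; mirrors the drop_last=True branch shown above.
    num_samples = 100        # len(sampler)
    batch_size = 8           # dataloader_args.batch_size
    num_left_samples = 36    # sampler.num_left_samples, i.e. samples not yet consumed

    batch_idx_in_epoch = num_samples // batch_size - num_left_samples // batch_size
    print(batch_idx_in_epoch)  # 12 - 4 = 8 full batches already consumed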

+5 -1   fastNLP/core/samplers/__init__.py

@@ -19,6 +19,10 @@ __all__ = [
"UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler",

+"RandomBatchSampler",
+"BucketedBatchSampler",
+"ReproducibleBatchSampler",
+
"re_instantiate_sampler",
"conversion_between_reproducible_and_unrepeated_sampler"
]
@@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
-from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler
+from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
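
With this change, all three batch-sampler names are importable from fastNLP.core.samplers; for example:

    from fastNLP.core.samplers import (
        ReproducibleBatchSampler,  # the abstract base class exported here
        RandomBatchSampler,
        BucketedBatchSampler,
    )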


+9 -0   fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -17,6 +17,9 @@ from abc import abstractmethod


class ReproducibleBatchSampler:
+def __init__(self, **kwargs):
+pass
+
@abstractmethod
def set_distributed(self, num_replicas, rank, pad=True):
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
@@ -41,6 +44,10 @@ class ReproducibleBatchSampler:
def set_epoch(self, epoch):
pass

+@property
+def batch_idx_in_epoch(self):
+raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
+

class RandomBatchSampler(ReproducibleBatchSampler):
# the values of these two parameters should be obtained via the driver's get_dataloader_args function;
@@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
:param drop_last: whether to drop the last batch if it cannot be filled with batch_size samples.
:param kwargs: used internally by fastNLP.
"""
+super().__init__()
+
self.batch_sampler = batch_sampler
self.batch_size = batch_size
self.drop_last = drop_last
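
For illustration only, a minimal custom subclass of the base class introduced above, implementing just the members visible in this diff; the real interface may require additional methods (e.g. for saving and loading state) that are not shown here, so treat this as a sketch rather than a complete implementation.

    from fastNLP.core.samplers import ReproducibleBatchSampler

    class FixedOrderBatchSampler(ReproducibleBatchSampler):
        """Hypothetical example: yields a pre-computed list of index batches in order."""

        def __init__(self, batches):
            super().__init__()
            self.batches = list(batches)
            self.consumed = 0  # number of batches already yielded in the current epoch

        def __iter__(self):
            while self.consumed < len(self.batches):
                batch = self.batches[self.consumed]
                self.consumed += 1
                yield batch
            self.consumed = 0  # reset once the epoch is exhausted

        def set_distributed(self, num_replicas, rank, pad=True):
            raise NotImplementedError("this sketch only supports single-process use")

        @property
        def batch_idx_in_epoch(self):
            return self.consumed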

