
Adjust the num_consumed_batches logic in driver save

tags/v1.0.0alpha
x54-729, 3 years ago
commit 5c29cd384a
2 changed files with 66 additions and 44 deletions
  1. fastNLP/core/drivers/paddle_driver/paddle_driver.py (+21, -21)
  2. tests/core/drivers/paddle_driver/test_single_device.py (+45, -23)
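
For orientation before the diffs: at save time the driver now derives num_consumed_samples from the number of batches the trainer actually consumed, because the sampler's own counter runs ahead of the trainer due to DataLoader prefetching. Below is a simplified, hypothetical sketch of that bookkeeping; resolve_num_consumed_samples and its sampler_is_batch_level flag are illustrative only and not fastNLP API, while the other names mirror the diff.

    # Hypothetical, simplified sketch of the save-time logic changed below; not fastNLP code.
    def resolve_num_consumed_samples(sampler_states, num_consumed_batches, batch_size,
                                     sampler_is_batch_level=True):
        # Optional per-position record of how many samples had truly been consumed;
        # more reliable than the sampler's live counter, which runs ahead because the
        # DataLoader prefetches batches.
        array = sampler_states.pop("num_consumed_samples_array", None)
        idx = num_consumed_batches
        if not sampler_is_batch_level and batch_size is not None:
            # A sample-level ReproducibleSampler records one entry per sample, so the
            # batch count is scaled by batch_size before indexing.
            idx = num_consumed_batches * batch_size
        if array is not None:
            return array[idx]
        if batch_size is not None:
            # No per-position record: estimate from the batch count.
            return num_consumed_batches * batch_size
        # batch_size unknown: keep the sampler's own counter and accept the lost precision.
        return sampler_states["num_consumed_samples"]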

fastNLP/core/drivers/paddle_driver/paddle_driver.py (+21, -21)

@@ -34,10 +34,10 @@ if _NEED_IMPORT_PADDLE:
     from paddle.optimizer import Optimizer
 
     _reduces = {
-        'max': paddle.max,
-        'min': paddle.min,
-        'mean': paddle.mean,
-        'sum': paddle.sum
+        "max": paddle.max,
+        "min": paddle.min,
+        "mean": paddle.mean,
+        "sum": paddle.sum
     }
 
 class PaddleDriver(Driver):
@@ -254,24 +254,24 @@ class PaddleDriver(Driver):
         else:
             raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
 
-        num_consumed_batches = states.pop('num_consumed_batches')
-        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+        num_consumed_batches = states.pop("num_consumed_batches")
+        if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
             # If present, num_consumed_samples needs special handling: because the DataLoader prefetches,
             # taking num_consumed_samples straight from the sampler would count more than was actually consumed.
-            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+            num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None)
             if num_consumed_samples_array is not None:
                 if isinstance(sampler, ReproducibleSampler):  # for a sample-level sampler, batch_size must be taken into account
                     try:
                         num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
                     except:  # batch_size may be None, in which case some precision is simply lost
                         num_consumed_batches = sampler_states['num_consumed_samples']
-                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+            else:
+                try:
+                    sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size
+                except:  # batch_size may be None, in which case some precision is simply lost
+                    pass
+            assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
         else:
             raise RuntimeError(
-                'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
+                "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
         states["sampler_states"] = sampler_states
 
         # 2. Save the model state;
         if should_save_model:
@@ -326,7 +326,7 @@ class PaddleDriver(Driver):
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
             )
-        sampler.load_state_dict(states['sampler_states'])
+        sampler.load_state_dict(states["sampler_states"])
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
 
         # 4. Adjust trainer_state.batch_idx_in_epoch
@@ -355,7 +355,7 @@ class PaddleDriver(Driver):
         return paddle.no_grad
 
     @staticmethod
-    def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']):
+    def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]):
         r"""
         Moves the model to the specified device.
         Note that using this in Paddle may cause problems when the device differs from the one that was configured.
@@ -363,7 +363,7 @@ class PaddleDriver(Driver):
         if device is not None:
             model.to(device)
 
-    def move_data_to_device(self, batch: 'paddle.Tensor'):
+    def move_data_to_device(self, batch: "paddle.Tensor"):
         r"""
         Moves the data to the specified device; batch may be a list, a dict, or a nested structure of these.
         Note that using this in Paddle may cause problems when the device differs from the one that was configured.
@@ -404,7 +404,7 @@ class PaddleDriver(Driver):
         if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
             dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank)
 
-    def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx):
+    def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
         r"""
         For a distributed sampler, the dataloader needs to set the random seed before every epoch, so that shuffling is identical on every process;
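
On the load side (the @@ -326 hunk above), the saved counter is what lets the rebuilt sampler skip everything that was already consumed. Below is a minimal, hypothetical round-trip sketch with a stand-in sampler class; DummySampler is not a fastNLP class, only the state_dict()/load_state_dict() contract mirrors the diff.

    # Hypothetical stand-in illustrating the state_dict()/load_state_dict() round trip.
    class DummySampler:
        def __init__(self):
            self.num_consumed_samples = 0

        def state_dict(self):
            return {"num_consumed_samples": self.num_consumed_samples}

        def load_state_dict(self, states):
            # Resume from the corrected counter computed at save time.
            self.num_consumed_samples = states["num_consumed_samples"]

    saved = {"sampler_states": {"num_consumed_samples": 8}}  # e.g. 2 consumed batches * batch_size 4
    sampler = DummySampler()
    sampler.load_state_dict(saved["sampler_states"])
    assert sampler.num_consumed_samples == 8  # iteration resumes after the 8th sample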



tests/core/drivers/paddle_driver/test_single_device.py (+45, -23)

@@ -224,7 +224,6 @@ class TestSetDistReproDataloder:
     """
     def setup_method(self):
         self.dataset = PaddleNormalDataset(20)
-        self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
         model = PaddleNormalModel_Classification_1(10, 32)
         self.driver = PaddleSingleDriver(model, device="cpu")
@@ -233,55 +232,59 @@ class TestSetDistReproDataloder:
         Tests the behaviour of set_dist_repro_dataloader when `reproducible` is False.
         When dist is a string, the original dataloader should be returned.
         """
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
 
-        assert replaced_loader is self.dataloader
+        assert replaced_loader is dataloader
 
     def test_set_dist_repro_dataloader_with_reproducible_true(self):
         """
         Tests the behaviour of set_dist_repro_dataloader when `reproducible` is True.
         When dist is a string, a new dataloader should be returned whose batch_sampler is a RandomBatchSampler.
         """
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
 
-        assert not (replaced_loader is self.dataloader)
+        assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
-        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
-        assert replaced_loader.drop_last == self.dataloader.drop_last
+        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
+        assert replaced_loader.drop_last == dataloader.drop_last
 
-        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)
+        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
         """
         Tests the behaviour of set_dist_repro_dataloader when dist is not a string and is a ReproducibleBatchSampler.
         A new dataloader should be returned, with batch_sampler replaced by the sampler passed as dist.
         """
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
         dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
 
-        assert not (replaced_loader is self.dataloader)
+        assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert replaced_loader.batch_sampler is dist
 
-        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dist_sampler(self):
         """
         Tests the behaviour of set_dist_repro_dataloader when dist is not a string.
         A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist.
         """
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
         dist = RandomSampler(self.dataset, shuffle=True)
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
 
-        assert not (replaced_loader is self.dataloader)
+        assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
-        assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert replaced_loader.batch_sampler.sampler is dist
-        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
+        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
 
-        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
         """
@@ -295,11 +298,12 @@ class TestSetDistReproDataloder:
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
 
-        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
         """
@@ -316,34 +320,52 @@ class TestSetDistReproDataloder:
 
         assert not (replaced_loader is dataloader)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
         assert replaced_loader.batch_sampler.batch_size == 2
         assert replaced_loader.batch_sampler.sampler.shuffle == True
 
-        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 

     def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
         """
         Checks that the result of running set_dist_repro_dataloader on a single card is correct.
         """
         # iterate two batches
+        # Here the BatchSampler may have yielded more batches than the dataloader actually handed out (prefetching).
+        num_consumed_batches = 2
         already_seen_idx = set()
-        for idx, batch in replaced_loader:
-            already_seen_idx.update(batch)
-            if idx >= 1:
+        for idx, batch in enumerate(replaced_loader):
+            if idx >= num_consumed_batches:
                 break
+            already_seen_idx.update(batch)
         if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
             sampler_states = replaced_loader.batch_sampler.state_dict()
         else:
             sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
+        print(sampler_states["data_idx"])
 
+        # Load num_consumed_samples_array and set the number of batches that were actually consumed.
+        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
 
         import time
         time.sleep(5)
 
         # Reload; the remaining content should come out, and for PaddleNormalDataset the sorted result should form a range.
         left_idxes = set()
         if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
+            batch_size = replaced_loader.batch_sampler.batch_size
+            if num_consumed_samples_array is not None:
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+            else:
+                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             replaced_loader.batch_sampler.load_state_dict(sampler_states)
         else:
+            batch_size = replaced_loader.batch_sampler.batch_size
+            if num_consumed_samples_array is not None:
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+            else:
+                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
             replaced_loader.batch_sampler.sampler.set_epoch(0)
         for idx, batch in enumerate(replaced_loader):
             left_idxes.update(batch)
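
The hunk ends here; any final assertion is outside this excerpt. A hypothetical completeness check of the intended behaviour (resume exactly where iteration stopped, with nothing replayed and nothing skipped), assuming the batch values were collected as plain ints (e.g. via batch.tolist()):

    # Hypothetical follow-up check, not shown in this hunk: together these two
    # assertions imply every sample appears exactly once across the two passes.
    assert len(already_seen_idx) + len(left_idxes) == len(self.dataset)
    assert len(already_seen_idx | left_idxes) == len(self.dataset)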


