Browse Source

paddle save函数适应新的sampler

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
77f6b63ba6
1 changed files with 14 additions and 5 deletions
  1. +14
    -5
      fastNLP/core/drivers/paddle_driver/paddle_driver.py

+ 14
- 5
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -247,18 +247,27 @@ class PaddleDriver(Driver):
# 会造成多余实际消耗的问题。
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler):
# 如果是 sampler 的话,需要计算出实际的 sample 数目
try:
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。
if dataloader_args.batch_size is not None:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except: # 有可能 batch_size 为 None,就只有损失精度了
else: # 有可能 batch_size 为 None,就只有损失精度了
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
states['sampler_states'] = sampler_states
else:
if dataloader_args.batch_size is not None:
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else:
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
else:
raise RuntimeError(
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
states['sampler_states'] = sampler_states

# 2. 保存模型的状态;
if should_save_model:


Loading…
Cancel
Save