@@ -247,18 +247,27 @@ class PaddleDriver(Driver):
             #   it would over-count the samples that were actually consumed.
             num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
             if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):  # if it is a sampler, batch_size has to be taken into account.
-                    try:
+                if isinstance(sampler, ReproducibleSampler):
+                    # if it is a sampler, the actual number of consumed samples has to be computed
+                    if dataloader_args.batch_size is not None:
                         num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    except:  # batch_size may be None, in that case we can only lose some precision
+                    else:  # batch_size may be None, in that case we can only lose some precision
+                        logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                       "it may cause missing some samples when reload.")
                         num_consumed_batches = sampler_states['num_consumed_samples']
                 sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
                 assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
-                states['sampler_states'] = sampler_states
+            else:
+                if dataloader_args.batch_size is not None:
+                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                             * num_consumed_batches
+                else:
+                    logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                   "it may cause missing some samples when reload.")
         else:
             raise RuntimeError(
                 "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
-
+        states['sampler_states'] = sampler_states
 
         # 2. Save the model state;
         if should_save_model:
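For reference, a minimal standalone sketch of the fallback arithmetic the new `else` branch performs when `num_consumed_samples_array` is not available; the names `num_replicas`, `batch_size` and `num_consumed_batches` mirror the diff, while the concrete values below are made up for illustration:

    # Sketch of the fallback computation: approximate the number of consumed
    # samples as replicas * per-replica batch size * consumed batches.
    # Illustrative values only; in the driver they come from the sampler and
    # from dataloader_args.
    num_replicas = 2            # e.g. number of data-parallel workers
    batch_size = 8              # per-replica batch size reported by the dataloader
    num_consumed_batches = 10   # batches already consumed at checkpoint time

    # Value stored in sampler_states['num_consumed_samples'] so the sampler
    # can skip already-seen samples on reload.
    num_consumed_samples = num_replicas * batch_size * num_consumed_batches
    assert num_consumed_samples == 160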