diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 37a5e59e..3b8ad7d8 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -244,17 +244,18 @@ class PaddleDriver(Driver): if hasattr(sampler, "state_dict") and callable(sampler.state_dict): sampler_states = sampler.state_dict() # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。 - num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) + # 会造成多余实际消耗的问题。 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - try: - sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 - pass - assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." - states["sampler_states"] = sampler_states + if isinstance(sampler, ReproducibleSampler): + # 如果是 sampler 的话,需要计算出实际的 sample 数目 + try: + num_consumed_batches = num_consumed_batches * dataloader_args.batch_size + except: # 有可能 batch_size 为 None,就只有损失精度了 + num_consumed_batches = sampler_states['num_consumed_samples'] + sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] + assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + states['sampler_states'] = sampler_states else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")