From 288eb36afbae4d3583e662fbc50f7eb9920b128f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 15 Apr 2022 09:05:44 +0000 Subject: [PATCH] =?UTF-8?q?=E6=96=AD=E7=82=B9=E9=87=8D=E8=AE=AD=20save?= =?UTF-8?q?=E6=97=B6=E7=9A=84=E9=80=BB=E8=BE=91=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 37a5e59e..3b8ad7d8 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -244,17 +244,18 @@ class PaddleDriver(Driver): if hasattr(sampler, "state_dict") and callable(sampler.state_dict): sampler_states = sampler.state_dict() # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。 - num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) + # 会造成多余实际消耗的问题。 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - try: - sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 - pass - assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." - states["sampler_states"] = sampler_states + if isinstance(sampler, ReproducibleSampler): + # 如果是 sampler 的话,需要计算出实际的 sample 数目 + try: + num_consumed_batches = num_consumed_batches * dataloader_args.batch_size + except: # 有可能 batch_size 为 None,就只有损失精度了 + num_consumed_batches = sampler_states['num_consumed_samples'] + sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] + assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + states['sampler_states'] = sampler_states else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")