From 77f6b63ba669e5844af7398892e2301e9f46a7e0 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 08:40:16 +0000 Subject: [PATCH] =?UTF-8?q?paddle=20save=E5=87=BD=E6=95=B0=E9=80=82?= =?UTF-8?q?=E5=BA=94=E6=96=B0=E7=9A=84sampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 75e0352f..fe8bf404 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -247,18 +247,27 @@ class PaddleDriver(Driver): # 会造成多余实际消耗的问题。 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): - # 如果是 sampler 的话,需要计算出实际的 sample 数目 - try: + if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 + if dataloader_args.batch_size is not None: num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 + else: # 有可能 batch_size 为 None,就只有损失精度了 + logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." - states['sampler_states'] = sampler_states + else: + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches + else: + logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") + + states['sampler_states'] = sampler_states # 2. 保存模型的状态; if should_save_model: