From 6db70df2a7eacd0df619b79ae31ee6f3f0becc90 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Wed, 31 Aug 2022 20:37:28 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=20torch=5Fdrive?= =?UTF-8?q?r.load=5Fmodel=20=E5=9C=A8ddp=20=E6=97=B6=E5=87=BA=E7=8E=B0?= =?UTF-8?q?=E6=98=BE=E5=AD=98=E7=88=86=E7=82=B8=E7=9A=84todo=E4=BA=8B?= =?UTF-8?q?=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/torch_driver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index db011403..6ca33476 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -190,6 +190,7 @@ class TorchDriver(Driver): :param load_state_dict: 保存的内容是否只是权重 """ model = self.unwrap_model() + # todo torch.load 在加载时会使得卡 0 多出一个(甚至多个)model 的显存;因此在多卡断点重训时可能会出现错误; res = torch.load(filepath, map_location='cpu') if isinstance(res, dict) and only_state_dict is False: logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " From 3f4a1f8e80686f5c7ad18eadf4dfc3bcb85a5694 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sat, 10 Sep 2022 11:10:31 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20fdl.=5FFDataS?= =?UTF-8?q?et=20=E4=B8=AD=E5=9C=A8=E4=BD=BF=E7=94=A8=20pickle=20=E6=97=B6?= =?UTF-8?q?=E5=87=BA=E7=8E=B0=E7=9A=84=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/torch_dataloader/fdl.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 5ae72367..8aa48382 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -47,6 +47,16 @@ class _FDataSet: 
def __len__(self) -> int: return len(self.dataset) + # 这里需要显式地带上这两个方法,因为可能会涉及到 pickle 的 dumps 和 loads;否则会导致 pickle 在 loads 时调用 __setstate__ 方法 + # 进入到 __getattr__ 内部,引发死循环; + # https://docs.python.org/3/library/pickle.html#pickling-class-instances + # https://stackoverflow.com/questions/73662315/when-using-multiprocessing-and-spawn-in-python-use-self-a-in-getattr-cause?noredirect=1 + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class TorchDataLoader(DataLoader): """