Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

dev0.8.0
x54-729 2 years ago
parent
commit
72e51c0f1f
2 changed files with 11 additions and 0 deletions
  1. +10
    -0
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  2. +1
    -0
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 10
- 0
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -47,6 +47,16 @@ class _FDataSet:
def __len__(self) -> int:
return len(self.dataset)

# 这里需要显示地带上这两个方法,因为可能会涉及到 pickle 的 dumps 和 loads;否则会导致 pickle 在 loads 时调用 __setstate__ 方法
# 进入到 __getattr__ 内部,引发死循环;
# https://docs.python.org/3/library/pickle.html#pickling-class-instances
# https://stackoverflow.com/questions/73662315/when-using-multiprocessing-and-spawn-in-python-use-self-a-in-getattr-cause?noredirect=1
def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__ = state


class TorchDataLoader(DataLoader):
"""


+ 1
- 0
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -190,6 +190,7 @@ class TorchDriver(Driver):
:param load_state_dict: 保存的内容是否只是权重
"""
model = self.unwrap_model()
# todo torch.load 在加载时会使得卡 0 多出一个(甚至多个)model 的显存;因此在多卡断点重训时可能会出现错误;
res = torch.load(filepath, map_location='cpu')
if isinstance(res, dict) and only_state_dict is False:
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "


Loading…
Cancel
Save