
Fix the argument mismatch of MixDataLoader under torch 1.6

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 296c1acc31
1 changed file with 17 additions and 7 deletions:
  fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py (+17, -7)


fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py

@@ -5,6 +5,7 @@ __all__ = [
 from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping
 
 import numpy as np
+from pkg_resources import parse_version
 
 from fastNLP.core.dataset import DataSet, Instance
 from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
@@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.collators import Collator
 
 if _NEED_IMPORT_TORCH:
+    from torch import __version__ as torchversion
     from torch.utils.data import DataLoader, Sampler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
@@ -213,13 +215,21 @@ class MixDataLoader(DataLoader):
         else:
             raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler")
 
-        super(MixDataLoader, self).__init__(
-            _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
-            batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
-            pin_memory=pin_memory, drop_last=False, timeout=0,
-            worker_init_fn=None, multiprocessing_context=None, generator=None,
-            prefetch_factor=2, persistent_workers=False
-        )
+        if parse_version(torchversion) >= parse_version('1.7'):
+            super(MixDataLoader, self).__init__(
+                _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
+                batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+                pin_memory=pin_memory, drop_last=False, timeout=0,
+                worker_init_fn=None, multiprocessing_context=None, generator=None,
+                prefetch_factor=2, persistent_workers=False
+            )
+        else:
+            super(MixDataLoader, self).__init__(
+                _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
+                batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+                pin_memory=pin_memory, drop_last=False, timeout=0,
+                worker_init_fn=None, multiprocessing_context=None, generator=None,
+            )
 
     def __iter__(self):
         return super().__iter__()
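For background: prefetch_factor and persistent_workers were only added to torch.utils.data.DataLoader in torch 1.7, so passing them under torch 1.6 fails with a TypeError about unexpected keyword arguments. A minimal sketch of the same version-gating idea, collecting the keyword arguments in a dict instead of duplicating the call (the make_loader helper and its parameters are illustrative, not part of this patch):

    from pkg_resources import parse_version
    from torch import __version__ as torchversion
    from torch.utils.data import DataLoader

    def make_loader(dataset, batch_sampler, num_workers=0):
        # Keyword arguments accepted by DataLoader on every torch version
        # this code supports.
        kwargs = dict(batch_size=1, shuffle=False, sampler=None,
                      batch_sampler=batch_sampler, num_workers=num_workers)
        # prefetch_factor and persistent_workers first appeared in torch 1.7;
        # torch 1.6 rejects them as unexpected keyword arguments.
        if parse_version(torchversion) >= parse_version('1.7'):
            kwargs.update(prefetch_factor=2, persistent_workers=False)
        return DataLoader(dataset, **kwargs)

The patch instead repeats the whole super().__init__() call in each branch, which keeps every argument visible at the call site; both forms avoid handing the 1.7-only keywords to torch 1.6.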
