import numpy as np
import random
import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data.dataloader import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
import signal
import sys
import threading
import traceback
import os
from torch._six import FileNotFoundError

from fastNLP.core.sampler import RandomSampler


class Batch(object):
    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
                 timeout=0.0, keep_process=False):
| """ | |||
| Batch is an iterable object which iterates over mini-batches. | |||
| """Batch is an iterable object which iterates over mini-batches. | |||
| Example:: | |||
| Example:: | |||
| iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler()) | |||
| for epoch in range(num_epochs): | |||
| for batch_x, batch_y in iterator: # 每次epoch会重新使用sampler生成index的。 | |||
| # ... | |||
| for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||
| # ... | |||
| :param DataSet dataset: a DataSet object | |||
| :param int batch_size: the size of the batch | |||
| :param Sampler sampler: a Sampler object | |||
| :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | |||
| :param num_workers: int, 使用多少个进程来准备数据。默认为0, 即使用主线程生成数据。 特性处于实验阶段,谨慎使用。 | |||
| 如果DataSet较大,且每个batch的准备时间很短,使用多进程可能并不能提速。 | |||
| :param pin_memory: bool, 默认为False. 设置为True时,有可能可以节省tensor从cpu移动到gpu的阻塞时间。 | |||
| :param timeout: float, 大于0的数,只有在num_workers>0时才有用。超过该时间仍然没有获取到一个batch则报错,可以用于 | |||
| 检测是否出现了batch产生阻塞的情况。 | |||
| :param keep_process: bool. 默认为False,该参数只在多进程下有效。在多进程的情况下,反复产生batch的iterator会导致 | |||
| 不断创建、销毁进程,可能对速度有一定的影响。当keep_process为True时,直到Batch对象被删除之前,多进程都没有关 | |||
| 闭。如果设置了keep_process为True,可以通过del BatchObject来删除Batch对象并关闭进程。 | |||
| """ | |||
| :param DataSet dataset: a DataSet object | |||
| :param int batch_size: the size of the batch | |||
| :param Sampler sampler: a Sampler object | |||
| :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | |||
| if num_workers < 0: | |||
| raise ValueError('num_workers option cannot be negative; ' | |||
| 'use num_workers=0 to disable multiprocessing.') | |||
| if timeout < 0: | |||
| raise ValueError('timeout option should be non-negative') | |||
| """ | |||
| def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): | |||
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler
        self.num_workers = num_workers
        self.keep_process = keep_process
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.as_numpy = as_numpy
        self.idx_list = None
        self.curidx = 0
        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
        self.cur_batch_indices = None
        self._data_iterator = None

    def __iter__(self):
        if self._data_iterator is not None:
            # regenerate the index list for the new epoch
            self._data_iterator.reset()
            return self._data_iterator
        elif self.keep_process and self.num_workers > 0:
            self._data_iterator = _DataLoaderIter(self)
            return self._data_iterator
        else:  # the common case
            return _DataLoaderIter(self)

    def __len__(self):
        return self.num_batches

    def get_batch_indices(self):
        return self.cur_batch_indices

    def __del__(self):
        if self.keep_process is True:
            del self._data_iterator
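

# A minimal usage sketch for the keep_process option documented above. It is illustrative
# only: `data_set` stands for any fastNLP DataSet whose input/target fields are already set.
#
#     batch_iter = Batch(data_set, batch_size=16, num_workers=2, keep_process=True)
#     for epoch in range(num_epochs):
#         for batch_x, batch_y in batch_iter:  # the same worker processes serve every epoch
#             ...
#     del batch_iter  # explicitly deleting the Batch shuts the workers down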


def to_tensor(batch, dtype):
    try:
        if dtype in (int, np.int64, np.int32, np.int8):
            batch = torch.LongTensor(batch)
        if dtype in (float, np.float64, np.float32):
            batch = torch.FloatTensor(batch)
    except Exception:
        pass
    return batch
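

# Illustration (added note, not part of the original code): on conversion failure or an
# unhandled dtype, to_tensor returns the input unchanged, so string or otherwise
# non-numeric fields pass through untouched.
#
#     to_tensor([[1, 2], [3, 4]], int)   # -> torch.LongTensor of shape (2, 2)
#     to_tensor([['a', 'b']], str)       # -> [['a', 'b']], unchanged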
| """ | |||
| 由于多进程涉及到大量问题,包括系统、安全关闭进程等。所以这里直接从pytorch的官方版本修改DataLoader实现多进程加速 | |||
| """ | |||

IS_WINDOWS = sys.platform == "win32"
if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

MANAGER_STATUS_CHECK_INTERVAL = 5.0

if IS_WINDOWS:
    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through the OS is to let the worker have a process
    # handle of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

        def is_alive(self):
            # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

        def is_alive(self):
            return os.getppid() == self.manager_pid
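
# Note (added for clarity, not in the original PyTorch code): on POSIX, a worker whose
# manager process has died is re-parented, so os.getppid() no longer matches the recorded
# manager_pid and is_alive() returns False, letting the worker exit instead of blocking
# forever on its index queue.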


def _worker_loop(dataset, index_queue, data_queue, seed, worker_id, as_numpy):
    # the loop that produces batches inside a worker process
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    watchdog = ManagerWatchdog()

    while True:
        try:
            # fetch the current batch counter and the indices of the current batch
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            # fetch the corresponding batch data; here this means reading the samples
            # from the dataset and padding them
            samples = _get_batch_from_dataset(dataset, batch_indices, as_numpy)
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
        else:
            data_queue.put((idx, samples, batch_indices))
            del samples
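
# Queue protocol sketch (a summary of the loop above, not part of the original): the main
# process sends (send_idx, batch_indices) tuples on each worker's index_queue; the worker
# replies on data_queue with (send_idx, batch_or_ExceptionWrapper, batch_indices), and a
# plain None on the index queue tells the worker to exit.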


def _get_batch_from_dataset(dataset, indices, as_numpy):
    """
    Given the indices, fetch (batch_x, batch_y) from the DataSet. If pin_memory is off, the data
    produced here goes straight to the Trainer; with pin_memory on, it additionally passes through
    a pin_memory() step.

    :param dataset: a fastNLP.DataSet object
    :param indices: List[int], the sample indices
    :param as_numpy: bool, whether to return numpy arrays instead of tensors
    :return: (batch_x, batch_y)
    """
    batch_x, batch_y = {}, {}
    for field_name, field in dataset.get_all_fields().items():
        if field.is_target or field.is_input:
            batch = field.get(indices)
            if not as_numpy and field.padder is not None:
                batch = to_tensor(batch, field.dtype)
            if field.is_target:
                batch_y[field_name] = batch
            if field.is_input:
                batch_x[field_name] = batch
    return batch_x, batch_y


def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    # move finished batches to the output queue, pinning their memory first if pin_memory is requested
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch, batch_indices = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
        else:
            out_queue.put((idx, batch, batch_indices))


def pin_memory_batch(batches):
    """
    Move every tensor in the batch dictionaries into page-locked (pinned) memory.

    :param batches: (batch_x, batch_y)
    :return: (batch_x, batch_y)
    """
    for batch_dict in batches:
        for field_name, batch in batch_dict.items():
            if isinstance(batch, torch.Tensor):
                batch_dict[field_name] = batch.pin_memory()
    return batches
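
# Why pinning helps (an illustrative sketch, not part of the original code): tensors in
# pinned memory can be copied to the GPU asynchronously, so the transfer can overlap
# with other CPU work.
#
#     batch_x, batch_y = pin_memory_batch((batch_x, batch_y))
#     words = batch_x['words'].cuda(non_blocking=True)  # 'words' is a hypothetical field name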


_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    # Windows doesn't support SIGCHLD handler
    if sys.platform == 'win32':
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        previous_handler = None

    def handler(signum, frame):
        # The following call uses `waitid` with WNOHANG from the C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True


class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, batcher):
        self.batcher = batcher
        self.dataset = batcher.dataset
        self.sampler = batcher.sampler
        self.as_numpy = batcher.as_numpy
        self.batch_size = batcher.batch_size
        self.num_workers = batcher.num_workers
        self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
        self.timeout = batcher.timeout
        self.keep_process = batcher.keep_process
        self.done_event = threading.Event()
        self.curidx = 0
        self.idx_list = self.sampler(self.dataset)

        # self.sample_iter returns one index at a time; it could be replaced by other means
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            # one index queue per worker
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            # holds the finished batches
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            # the workers push batch data into self.worker_result_queue; nothing has been
            # moved to a device yet at this point
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, base_seed + i, i, self.as_numpy))
                for i in range(self.num_workers)]

            # batches are read from self.data_queue; with pin_memory, a manager thread
            # moves them into this separate queue after pinning
            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            # start the workers
            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()
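
            # Prefetch invariant (added note, not in the original): with num_workers == 2 the
            # loop above enqueues 4 batches up front; afterwards _process_next_batch() issues
            # one new _put_indices() per batch consumed, so batches_outstanding stays at most
            # 2 * num_workers until the index list is exhausted.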

    def reset(self):
        """
        Reset curidx and re-sample idx_list. Only useful when keep_process is set.

        :return:
        """
        if self.keep_process:
            self.curidx = 0
            self.idx_list = self.sampler(self.dataset)
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            if self.curidx >= len(self.idx_list):
                raise StopIteration
            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
            # simply fetch the data from the dataset
            indices = self.idx_list[self.curidx:endidx]
            self.batcher.cur_batch_indices = indices
            batch_x, batch_y = _get_batch_from_dataset(dataset=self.dataset, indices=indices,
                                                       as_numpy=self.as_numpy)
            if self.pin_memory:
                batch_x, batch_y = pin_memory_batch((batch_x, batch_y))
            self.curidx = endidx
            return batch_x, batch_y

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        # stop once no batches are outstanding
        if self.batches_outstanding == 0:
            if not self.keep_process:
                self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch, batch_indices = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            self.batcher.cur_batch_indices = batch_indices
            return self._process_next_batch(batch)
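
    # Reordering example (added note, not in the original): batches are handed back in send
    # order even though workers finish out of order. If batches 0, 1, 2 are in flight and
    # batch 1 arrives first, it is parked in reorder_dict under key 1; __next__ keeps reading
    # until batch 0 (the current rcvd_idx) arrives, and batch 1 is popped on the next call.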

    def __iter__(self):
        self.curidx = 0
        return self

    def _put_indices(self):
        # push the indices of one batch onto a worker's index queue
        assert self.batches_outstanding < 2 * self.num_workers
        if self.curidx >= len(self.idx_list):
            indices = None
        else:
            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
            indices = self.idx_list[self.curidx:endidx]
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.curidx = endidx
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        # merely triggers producing the indices for the next batch
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                for q in self.index_queues:
                    q.put(None)
                # if some workers are waiting to put, make place for them
                try:
                    while not self.worker_result_queue.empty():
                        self.worker_result_queue.get()
                except (FileNotFoundError, ImportError):
                    # Many weird errors can happen here due to Python
                    # shutting down. These are more like obscure Python bugs.
                    # FileNotFoundError can happen when we rebuild the fd
                    # fetched from the queue but the socket is already closed
                    # from the worker side.
                    # ImportError can happen when the unpickler loads the
                    # resource from `get`.
                    pass
                # done_event should be sufficient to exit worker_manager_thread,
                # but be safe here and put another None
                self.worker_result_queue.put(None)
        finally:
            # removes pids no matter what
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()
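

# End-to-end sketch (illustrative only; `ds` and its field names are hypothetical, while
# the DataSet constructor and set_input/set_target are the usual fastNLP API):
#
#     from fastNLP import DataSet
#     ds = DataSet({'x': [[1, 2], [3, 4]] * 10, 'y': [0, 1] * 10})
#     ds.set_input('x')
#     ds.set_target('y')
#     for batch_x, batch_y in Batch(ds, batch_size=4, num_workers=2, timeout=5.0):
#         pass  # raises RuntimeError if no batch arrives within 5 seconds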