From a7f3701bdf3fc48e4caa92210ded14bd8ca19852 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 19 Jan 2019 16:26:39 +0800
Subject: [PATCH] - revert batch

---
 fastNLP/core/batch.py | 491 ++++--------------------------------------
 1 file changed, 42 insertions(+), 449 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 9dbf9604..d4fcbf23 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -1,72 +1,63 @@
 import numpy as np
-import random
 import torch
-import torch.multiprocessing as multiprocessing
-from torch.utils.data.dataloader import _set_worker_signal_handlers, _update_worker_pids, \
-    _remove_worker_pids, _error_if_any_worker_fails
-import signal
-import sys
-import threading
-import traceback
-import os
-from torch._six import FileNotFoundError
 
 from fastNLP.core.sampler import RandomSampler
 
+
 class Batch(object):
-    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
-                 timeout=0.0, keep_process=False):
-        """
-        Batch is an iterable object which iterates over mini-batches.
+    """Batch is an iterable object which iterates over mini-batches.
+
+    Example::
 
-        Example::
-            iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler())
-            for epoch in range(num_epochs):
-                for batch_x, batch_y in iterator:  # each epoch re-runs the sampler to generate the indices
-                    # ...
+        for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
+            # ...
 
-        :param DataSet dataset: a DataSet object
-        :param int batch_size: the size of the batch
-        :param Sampler sampler: a Sampler object
-        :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
-        :param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. the main thread produces the data. This feature is experimental; use with caution.
-            If the DataSet is large and each batch is quick to prepare, multiprocessing may not bring a speed-up.
-        :param pin_memory: bool, defaults to False. When True, it may cut down the blocking time of moving tensors from CPU to GPU.
-        :param timeout: float, a positive number, only effective when num_workers>0. If no batch has been obtained within this time an error is raised, which can be used to
-            detect whether batch production has stalled.
-        :param keep_process: bool. Defaults to False; only effective with multiprocessing. Repeatedly recreating the batch iterator
-            keeps creating and destroying worker processes, which may have some impact on speed. When keep_process is True, the
-            workers are not shut down until the Batch object is deleted; with keep_process=True you can use `del BatchObject` to delete the Batch object and close the workers.
-        """
+    :param DataSet dataset: a DataSet object
+    :param int batch_size: the size of the batch
+    :param Sampler sampler: a Sampler object
+    :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
-        if num_workers < 0:
-            raise ValueError('num_workers option cannot be negative; '
-                             'use num_workers=0 to disable multiprocessing.')
-        if timeout < 0:
-            raise ValueError('timeout option should be non-negative')
+    """
 
+    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False):
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
-        self.num_workers = num_workers
-        self.keep_process = keep_process
-        self.pin_memory = pin_memory
-        self.timeout = timeout
         self.as_numpy = as_numpy
+        self.idx_list = None
+        self.curidx = 0
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
-        self._data_iterator = None
 
     def __iter__(self):
-        if self._data_iterator is not None:
-            # regenerate the index_list
-            self._data_iterator.reset()
-            return self._data_iterator
-        elif self.keep_process and self.num_workers>0:
-            self._data_iterator = _DataLoaderIter(self)
-            return self._data_iterator
-        else:  # this is the common case
-            return _DataLoaderIter(self)
+        self.idx_list = self.sampler(self.dataset)
+        self.curidx = 0
+        self.lengths = self.dataset.get_length()
+        return self
+
+    def __next__(self):
+        if self.curidx >= len(self.idx_list):
+            raise StopIteration
+        else:
+            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
+            batch_x, batch_y = {}, {}
+
+            indices = self.idx_list[self.curidx:endidx]
+            self.cur_batch_indices = indices
+
+            for field_name, field in self.dataset.get_all_fields().items():
+                if field.is_target or field.is_input:
+                    batch = field.get(indices)
+                    if not self.as_numpy and field.padder is not None:
+                        batch = to_tensor(batch, field.dtype)
+                    if field.is_target:
+                        batch_y[field_name] = batch
+                    if field.is_input:
+                        batch_x[field_name] = batch
+
+            self.curidx = endidx
+
+            return batch_x, batch_y
 
     def __len__(self):
         return self.num_batches
@@ -74,11 +65,6 @@ class Batch(object):
     def get_batch_indices(self):
         return self.cur_batch_indices
 
-    def __del__(self):
-        if self.keep_process is True:
-            del self._data_iterator
-
-
 def to_tensor(batch, dtype):
     try:
@@ -89,396 +75,3 @@ def to_tensor(batch, dtype):
     except:
         pass
     return batch
-
-
-"""
-Multiprocessing involves a large number of issues (OS specifics, shutting processes down safely, and so on), so the multiprocess speed-up here directly adapts the DataLoader from the official PyTorch release.
-"""
-
-IS_WINDOWS = sys.platform == "win32"
-if IS_WINDOWS:
-    import ctypes
-    from ctypes.wintypes import DWORD, BOOL, HANDLE
-
-if sys.version_info[0] == 2:
-    import Queue as queue
-else:
-    import queue
-
-
-class ExceptionWrapper(object):
-    r"""Wraps an exception plus traceback to communicate across threads"""
-
-    def __init__(self, exc_info):
-        self.exc_type = exc_info[0]
-        self.exc_msg = "".join(traceback.format_exception(*exc_info))
-
-
-_use_shared_memory = False
-r"""Whether to use shared memory in default_collate"""
-
-MANAGER_STATUS_CHECK_INTERVAL = 5.0
-
-if IS_WINDOWS:
-    # On Windows, the parent ID of the worker process remains unchanged when the manager process
-    # is gone, and the only way to check it through OS is to let the worker have a process handle
-    # of the manager and ask if the process status has changed.
-    class ManagerWatchdog(object):
-        def __init__(self):
-            self.manager_pid = os.getppid()
-
-            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
-            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
-            self.kernel32.OpenProcess.restype = HANDLE
-            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
-            self.kernel32.WaitForSingleObject.restype = DWORD
-
-            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
-            SYNCHRONIZE = 0x00100000
-            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
-
-            if not self.manager_handle:
-                raise ctypes.WinError(ctypes.get_last_error())
-
-        def is_alive(self):
-            # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
-            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
-else:
-    class ManagerWatchdog(object):
-        def __init__(self):
-            self.manager_pid = os.getppid()
-
-        def is_alive(self):
-            return os.getppid() == self.manager_pid
-
-
-def _worker_loop(dataset, index_queue, data_queue, seed, worker_id, as_numpy):
-    # the loop that produces the data
-    global _use_shared_memory
-    _use_shared_memory = True
-
-    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
-    # module's handlers are executed after Python returns from C low-level
-    # handlers, likely when the same fatal signal happened again already.
-    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
-    _set_worker_signal_handlers()
-
-    torch.set_num_threads(1)
-    random.seed(seed)
-    torch.manual_seed(seed)
-
-    watchdog = ManagerWatchdog()
-
-    while True:
-        try:
-            # fetch the current batch counter and the indices of the current batch
-            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
-        except queue.Empty:
-            if watchdog.is_alive():
-                continue
-            else:
-                break
-        if r is None:
-            break
-        idx, batch_indices = r
-        try:
-            # fetch the corresponding batch data; modified here to pull the data out of the dataset and finish padding
-            samples = _get_batch_from_dataset(dataset, batch_indices, as_numpy)
-        except Exception:
-            data_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
-        else:
-            data_queue.put((idx, samples, batch_indices))
-            del samples
-
-def _get_batch_from_dataset(dataset, indices, as_numpy):
-    """
-    Given indices, fetch (batch_x, batch_y) from the DataSet. Once produced here, the data is handed straight to the
-    Trainer if pin_memory is off; with pin_memory it additionally goes through a pin_memory() pass.
-    :param dataset: a fastNLP.DataSet object
-    :param indices: List[int], indices
-    :param as_numpy: bool, whether to convert to numpy only
-    :return: (batch_x, batch_y)
-    """
-    batch_x, batch_y = {}, {}
-    for field_name, field in dataset.get_all_fields().items():
-        if field.is_target or field.is_input:
-            batch = field.get(indices)
-            if not as_numpy and field.padder is not None:
-                batch = to_tensor(batch, field.dtype)
-            if field.is_target:
-                batch_y[field_name] = batch
-            if field.is_input:
-                batch_x[field_name] = batch
-
-    return batch_x, batch_y
-
-
-def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
-    # send the data into the designated queue; i.e. if pin_memory is required, pin it before handing it over
-    if pin_memory:
-        torch.cuda.set_device(device_id)
-
-    while True:
-        try:
-            r = in_queue.get()
-        except Exception:
-            if done_event.is_set():
-                return
-            raise
-        if r is None:
-            break
-        if isinstance(r[1], ExceptionWrapper):
-            out_queue.put(r)
-            continue
-        idx, batch, batch_indices = r
-        try:
-            if pin_memory:
-                batch = pin_memory_batch(batch)
-        except Exception:
-            out_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
-        else:
-            out_queue.put((idx, batch, batch_indices))
-
-
-def pin_memory_batch(batchs):
-    """
-
-    :param batchs: (batch_x, batch_y)
-    :return: (batch_x, batch_y)
-    """
-    for batch_dict in batchs:
-        for field_name, batch in batch_dict.items():
-            if isinstance(batch, torch.Tensor):
-                batch_dict[field_name] = batch.pin_memory()
-    return batchs
-
-
-_SIGCHLD_handler_set = False
-r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
-handler needs to be set for all DataLoaders in a process."""
-
-
-def _set_SIGCHLD_handler():
-    # Windows doesn't support SIGCHLD handler
-    if sys.platform == 'win32':
-        return
-    # can't set signal in child threads
-    if not isinstance(threading.current_thread(), threading._MainThread):
-        return
-    global _SIGCHLD_handler_set
-    if _SIGCHLD_handler_set:
-        return
-    previous_handler = signal.getsignal(signal.SIGCHLD)
-    if not callable(previous_handler):
-        previous_handler = None
-
-    def handler(signum, frame):
-        # The following call uses `waitid` with WNOHANG from C side. Therefore,
-        # Python can still get and update the process status successfully.
-        _error_if_any_worker_fails()
-        if previous_handler is not None:
-            previous_handler(signum, frame)
-
-    signal.signal(signal.SIGCHLD, handler)
-    _SIGCHLD_handler_set = True
-
-
-class _DataLoaderIter(object):
-    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
-
-    def __init__(self, batcher):
-        self.batcher = batcher
-        self.dataset = batcher.dataset
-        self.sampler = batcher.sampler
-        self.as_numpy = batcher.as_numpy
-        self.batch_size = batcher.batch_size
-        self.num_workers = batcher.num_workers
-        self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
-        self.timeout = batcher.timeout
-        self.keep_process = batcher.keep_process
-        self.done_event = threading.Event()
-        self.curidx = 0
-        self.idx_list = self.sampler(self.dataset)
-
-        # self.sample_iter returns one index at a time; this could be replaced by some other mechanism
-
-        base_seed = torch.LongTensor(1).random_().item()
-
-        if self.num_workers > 0:
-            # build one index queue per worker
-            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
-            self.worker_queue_idx = 0
-            # holds the fetched batches
-            self.worker_result_queue = multiprocessing.SimpleQueue()
-            self.batches_outstanding = 0
-            self.worker_pids_set = False
-            self.shutdown = False
-            self.send_idx = 0
-            self.rcvd_idx = 0
-            self.reorder_dict = {}
-
-            # the batch data is pushed into self.worker_result_queue here, but has not been moved to the device yet
-            self.workers = [
-                multiprocessing.Process(
-                    target=_worker_loop,
-                    args=(self.dataset, self.index_queues[i],
-                          self.worker_result_queue, base_seed + i, i, self.as_numpy))
-                for i in range(self.num_workers)]
-
-            # just take the data from self.data_queue; with pin_memory, the data is put into another queue
-            if self.pin_memory or self.timeout > 0:
-                self.data_queue = queue.Queue()
-                if self.pin_memory:
-                    maybe_device_id = torch.cuda.current_device()
-                else:
-                    # do not initialize cuda context if not necessary
-                    maybe_device_id = None
-                self.worker_manager_thread = threading.Thread(
-                    target=_worker_manager_loop,
-                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
-                          maybe_device_id))
-                self.worker_manager_thread.daemon = True
-                self.worker_manager_thread.start()
-            else:
-                self.data_queue = self.worker_result_queue
-
-            # set the workers to work
-            for w in self.workers:
-                w.daemon = True  # ensure that the worker exits on process exit
-                w.start()
-
-            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
-            _set_SIGCHLD_handler()
-            self.worker_pids_set = True
-
-            # prime the prefetch loop
-            for _ in range(2 * self.num_workers):
-                self._put_indices()
-
-    def reset(self):
-        """
-        Reset curidx and re-sample the idx_list. Only useful when keep_process is needed.
-        :return:
-        """
-        if self.keep_process:
-            self.curidx = 0
-            self.idx_list = self.sampler(self.dataset)
-            for _ in range(2 * self.num_workers):
-                self._put_indices()
-
-    def _get_batch(self):
-        if self.timeout > 0:
-            try:
-                return self.data_queue.get(timeout=self.timeout)
-            except queue.Empty:
-                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
-        else:
-            return self.data_queue.get()
-
-    def __next__(self):
-        if self.num_workers == 0:  # same-process loading
-            if self.curidx >= len(self.idx_list):
-                raise StopIteration
-            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            # just collect the data from the dataset directly
-            indices = self.idx_list[self.curidx:endidx]
-            self.batcher.cur_batch_indices = indices
-            batch_x, batch_y = _get_batch_from_dataset(dataset=self.dataset, indices=indices,
-                                                       as_numpy=self.as_numpy)
-            if self.pin_memory:
-                batch_x, batch_y = pin_memory_batch((batch_x, batch_y))
-            self.curidx = endidx
-            return batch_x, batch_y
-
-        # check if the next sample has already been generated
-        if self.rcvd_idx in self.reorder_dict:
-            batch = self.reorder_dict.pop(self.rcvd_idx)
-            return self._process_next_batch(batch)
-
-        # stop once no batches are outstanding
-        if self.batches_outstanding == 0:
-            if not self.keep_process:
-                self._shutdown_workers()
-            raise StopIteration
-
-        while True:
-            assert (not self.shutdown and self.batches_outstanding > 0)
-            idx, batch, batch_indices = self._get_batch()
-            self.batches_outstanding -= 1
-            if idx != self.rcvd_idx:
-                # store out-of-order samples
-                self.reorder_dict[idx] = batch
-                continue
-            self.batcher.cur_batch_indices = batch_indices
-            return self._process_next_batch(batch)
-
-    def __iter__(self):
-        self.curidx = 0
-
-        return self
-
-    def _put_indices(self):
-        # put indices into the index queue that the workers collect data from
-        assert self.batches_outstanding < 2 * self.num_workers
-        if self.curidx >= len(self.idx_list):
-            indices = None
-        else:
-            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            # just collect the data from the dataset directly
-            indices = self.idx_list[self.curidx:endidx]
-        if indices is None:
-            return
-        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
-        self.curidx = endidx
-        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
-        self.batches_outstanding += 1
-        self.send_idx += 1
-
-    def _process_next_batch(self, batch):
-        # this just triggers producing the indices for the next batch
-        self.rcvd_idx += 1
-        self._put_indices()
-        if isinstance(batch, ExceptionWrapper):
-            raise batch.exc_type(batch.exc_msg)
-        return batch
-
-    def __getstate__(self):
-        # TODO: add limited pickling support for sharing an iterator
-        # across multiple threads for HOGWILD.
-        # Probably the best way to do this is by moving the sample pushing
-        # to a separate thread and then just sharing the data queue
-        # but signalling the end is tricky without a non-blocking API
-        raise NotImplementedError("_DataLoaderIter cannot be pickled")
-
-    def _shutdown_workers(self):
-        try:
-            if not self.shutdown:
-                self.shutdown = True
-                self.done_event.set()
-                for q in self.index_queues:
-                    q.put(None)
-                # if some workers are waiting to put, make place for them
-                try:
-                    while not self.worker_result_queue.empty():
-                        self.worker_result_queue.get()
-                except (FileNotFoundError, ImportError):
-                    # Many weird errors can happen here due to Python
-                    # shutting down. These are more like obscure Python bugs.
-                    # FileNotFoundError can happen when we rebuild the fd
-                    # fetched from the queue but the socket is already closed
-                    # from the worker side.
-                    # ImportError can happen when the unpickler loads the
-                    # resource from `get`.
-                    pass
-                # done_event should be sufficient to exit worker_manager_thread,
-                # but be safe here and put another None
-                self.worker_result_queue.put(None)
-        finally:
-            # removes pids no matter what
-            if self.worker_pids_set:
-                _remove_worker_pids(id(self))
-                self.worker_pids_set = False
-
-    def __del__(self):
-        if self.num_workers > 0:
-            self._shutdown_workers()
\ No newline at end of file
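
A minimal usage sketch of the reverted API, assuming the fastNLP DataSet interface this patch targets (a DataSet built
from a dict of lists, with set_input/set_target marking which fields land in batch_x and batch_y; the field names "x"
and "y" are illustrative only):

    from fastNLP.core.batch import Batch
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.sampler import SequentialSampler

    # toy data: "x" is a variable-length input field, "y" a target field
    data_set = DataSet({"x": [[1, 2], [3, 4, 5], [6]], "y": [0, 1, 0]})
    data_set.set_input("x")    # is_input fields are collected into batch_x
    data_set.set_target("y")   # is_target fields are collected into batch_y

    # each pass over Batch re-runs the sampler; within a batch, fields are
    # padded by their padder and converted via to_tensor unless as_numpy=True
    for batch_x, batch_y in Batch(data_set, batch_size=2, sampler=SequentialSampler()):
        print(batch_x["x"].shape, batch_y["y"])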