From d9ac3344093e53b0ce3cbe2c25e45ffaa35b6a99 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Fri, 18 Jan 2019 23:33:19 +0800
Subject: [PATCH] Reduce the overhead of repeatedly creating worker processes in Batch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/batch.py     | 50 ++++++++++++++++++++++++++++++++-------
 fastNLP/core/trainer.py   |  9 ++++---
 test/core/test_trainer.py | 31 +++++++++++++++++++++---
 3 files changed, 75 insertions(+), 15 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 05bd5665..9dbf9604 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -15,24 +15,28 @@ from fastNLP.core.sampler import RandomSampler
 class Batch(object):
 
     def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
-                 timeout=0.0):
+                 timeout=0.0, keep_process=False):
         """
         Batch is an iterable object which iterates over mini-batches.
 
             Example::
-
-                for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
-                    # ...
+                iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler())
+                for epoch in range(num_epochs):
+                    for batch_x, batch_y in iterator:  # each epoch the sampler is used again to regenerate the indices
+                        # ...
 
         :param DataSet dataset: a DataSet object
         :param int batch_size: the size of the batch
         :param Sampler sampler: a Sampler object
-        :param bool as_numpy: If True, return Numpy array when possible. Otherwise, return torch tensors.
+        :param bool as_numpy: If True, return Numpy arrays. Otherwise, return torch tensors.
         :param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. the data is generated in the main thread. Experimental feature, use with caution.
             If the DataSet is large and each batch takes little time to prepare, multiprocessing may not bring a speed-up.
         :param pin_memory: bool, defaults to False. Setting it to True may reduce the blocking time of moving tensors from CPU to GPU.
         :param timeout: float, a number greater than 0; only effective when num_workers>0. If no batch is obtained within this time an error is raised;
             useful for detecting whether batch generation has stalled.
+        :param keep_process: bool, defaults to False; only effective with multiprocessing. With multiple workers, repeatedly creating batch iterators keeps
+            creating and destroying worker processes, which may hurt speed. When keep_process is True, the worker processes are not closed until the Batch
+            object is deleted. With keep_process=True, the Batch object can be deleted (closing the processes) via del BatchObject.
         """
 
         if num_workers < 0:
@@ -45,15 +49,24 @@ class Batch(object):
         self.batch_size = batch_size
         self.sampler = sampler
         self.num_workers = num_workers
+        self.keep_process = keep_process
         self.pin_memory = pin_memory
         self.timeout = timeout
         self.as_numpy = as_numpy
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
+        self._data_iterator = None
 
     def __iter__(self):
-        # TODO: currently, with multiple workers, every loop re-creates the worker processes, which may be costly. Consider reusing the iterator directly.
-        return _DataLoaderIter(self)
+        if self._data_iterator is not None:
+            # reset the index_list
+            self._data_iterator.reset()
+            return self._data_iterator
+        elif self.keep_process and self.num_workers > 0:
+            self._data_iterator = _DataLoaderIter(self)
+            return self._data_iterator
+        else:  # this is the common case
+            return _DataLoaderIter(self)
 
     def __len__(self):
         return self.num_batches
@@ -61,6 +74,12 @@ class Batch(object):
     def get_batch_indices(self):
         return self.cur_batch_indices
 
+    def __del__(self):
+        if self.keep_process is True:
+            del self._data_iterator
+
+
+
 def to_tensor(batch, dtype):
     try:
         if dtype in (int, np.int8, np.int16, np.int32, np.int64):
@@ -276,6 +295,7 @@ class _DataLoaderIter(object):
         self.num_workers = batcher.num_workers
         self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
         self.timeout = batcher.timeout
+        self.keep_process = batcher.keep_process
         self.done_event = threading.Event()
         self.curidx = 0
         self.idx_list = self.sampler(self.dataset)
@@ -335,6 +355,17 @@ class _DataLoaderIter(object):
         for _ in range(2 * self.num_workers):
             self._put_indices()
 
+    def reset(self):
+        """
+        Reset curidx and re-sample idx_list. Only useful when keep_process is enabled.
+        :return:
+        """
+        if self.keep_process:
+            self.curidx = 0
+            self.idx_list = self.sampler(self.dataset)
+            for _ in range(2 * self.num_workers):
+                self._put_indices()
+
     def _get_batch(self):
         if self.timeout > 0:
             try:
@@ -366,7 +397,8 @@ class _DataLoaderIter(object):
 
         # if there is no outstanding data left to produce, stop
         if self.batches_outstanding == 0:
-            self._shutdown_workers()
+            if not self.keep_process:
+                self._shutdown_workers()
             raise StopIteration
 
         while True:
@@ -449,4 +481,4 @@ class _DataLoaderIter(object):
 
     def __del__(self):
         if self.num_workers > 0:
-            self._shutdown_workers()
+            self._shutdown_workers()
\ No newline at end of file
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 76a8562b..07d94d11 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -61,7 +61,8 @@ class Trainer(object):
         :param BaseSampler sampler: method used to generate batch data.
         :param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. the data is generated in the main thread. Experimental feature, use with caution.
             If the DataSet is large and each batch takes little time to prepare, multiprocessing may not bring a speed-up.
-        :param pin_memory: bool, defaults to False. Setting it to True may reduce the blocking time of moving tensors from CPU to GPU.
+        :param pin_memory: bool, defaults to False. When set to True, page-locked (pinned) memory is used, which may increase memory consumption. If memory is
+            plentiful, consider setting it to True for a speed-up; when pin_memory is True, data is moved from CPU to GPU with non_blocking=True by default.
         :param timeout: float, a number greater than 0; only effective when num_workers>0. If no batch is obtained within this time an error is raised;
             useful for detecting whether batch generation has stalled.
         :param bool use_tqdm: whether to use tqdm to show train progress.
@@ -246,7 +247,8 @@ class Trainer(object):
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout)
+                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
+                                  keep_process=True)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -255,7 +257,8 @@ class Trainer(object):
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.before_batch(batch_x, batch_y, indices)
-                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
+                                               non_blocking=self.pin_memory)  # with pin_memory, use a non_blocking transfer
                    prediction = self._data_forward(self.model, batch_x)
 
                     # edit prediction
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index 624f2587..7c869633 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -237,6 +237,31 @@ class TrainerTestGround(unittest.TestCase):
                           use_tqdm=False,
                           print_every=2)
 
-    def test_case2(self):
-        # check metrics Wrong
-        data_set = prepare_fake_dataset2('x1', 'x2')
+    def test_trainer_multiprocess(self):
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2', 'y', flag=True)
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=True,
+            print_every=2,
+            num_workers=2,
+            pin_memory=False,
+            timeout=0,
+        )
+        trainer.train()
+
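
For reviewers, a minimal usage sketch of the keep_process option added above. It assumes a prepared fastNLP DataSet named data_set and the import paths of this repo; the epoch count and the pass placeholder are illustrative, not part of the patch:

    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import SequentialSampler

    # keep_process=True: the worker processes started for the first epoch are
    # kept alive and reused; reset() re-samples the indices on each new epoch.
    iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler(),
                     num_workers=2, keep_process=True)

    for epoch in range(10):  # placeholder epoch count
        for batch_x, batch_y in iterator:
            pass  # forward/backward pass goes here

    del iterator  # deleting the Batch object shuts the worker processes down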
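
The trainer change pairs pin_memory with a non_blocking transfer. The stand-alone PyTorch sketch below only illustrates that general pattern and is not fastNLP code; the tensor shape is arbitrary:

    import torch

    # Pinned (page-locked) host memory allows the host-to-GPU copy to run
    # asynchronously, so .to(device, non_blocking=True) can overlap with other work.
    batch = torch.randn(32, 128).pin_memory()  # uses extra host RAM
    if torch.cuda.is_available():
        batch_gpu = batch.to(torch.device("cuda"), non_blocking=True)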