From 443184f82e58af8940309b73b8ee6b2619ca679a Mon Sep 17 00:00:00 2001 From: yunfan Date: Mon, 6 May 2019 14:55:40 +0800 Subject: [PATCH] - fix batch & vocab --- fastNLP/core/batch.py | 98 +++++++++++++++++++++----------------- fastNLP/core/vocabulary.py | 4 +- 2 files changed, 56 insertions(+), 46 deletions(-) diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 4af6d651..235a9a3a 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -9,6 +9,7 @@ import atexit from .sampler import RandomSampler, Sampler import torch.multiprocessing as mp +from queue import Empty, Full _python_is_exit = False @@ -92,7 +93,7 @@ class Batch(object): :return: """ if self.prefetch: - return _run_batch_iter(self) + return self._run_batch_iter(self) def batch_iter(): self.init_iter() @@ -120,6 +121,57 @@ class Batch(object): """ return self.cur_batch_indices + @staticmethod + def _run_fetch(batch, q): + try: + global _python_is_exit + batch.init_iter() + # print('start fetch') + while 1: + res = batch.fetch_one() + # print('fetch one') + while 1: + try: + q.put(res, timeout=3) + break + except Full: + if _python_is_exit: + return + if res is None: + # print('fetch done, waiting processing') + break + # print('fetch exit') + except Exception as e: + q.put(e) + finally: + q.join() + + @staticmethod + def _run_batch_iter(batch): + q = mp.JoinableQueue(maxsize=10) + fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q)) + fetch_p.daemon = True + fetch_p.start() + # print('fork fetch process') + while 1: + try: + res = q.get(timeout=1) + q.task_done() + # print('get fetched') + if res is None: + break + elif isinstance(res, Exception): + raise res + yield res + except Empty as e: + if fetch_p.is_alive(): + continue + else: + break + fetch_p.terminate() + fetch_p.join() + # print('iter done') + def _to_tensor(batch, dtype): try: @@ -131,47 +183,3 @@ def _to_tensor(batch, dtype): pass return batch - -def _run_fetch(batch, q): - global _python_is_exit - batch.init_iter() - # print('start fetch') - while 1: - res = batch.fetch_one() - # print('fetch one') - while 1: - try: - q.put(res, timeout=3) - break - except: - if _python_is_exit: - return - if res is None: - # print('fetch done, waiting processing') - q.join() - break - # print('fetch exit') - - -def _run_batch_iter(batch): - q = mp.JoinableQueue(maxsize=10) - fetch_p = mp.Process(target=_run_fetch, args=(batch, q)) - fetch_p.daemon = True - fetch_p.start() - # print('fork fetch process') - while 1: - try: - res = q.get(timeout=1) - q.task_done() - # print('get fetched') - if res is None: - break - yield res - except Exception as e: - if fetch_p.is_alive(): - continue - else: - break - fetch_p.terminate() - fetch_p.join() - # print('iter done') diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index c82c316e..0dc232e4 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -110,7 +110,8 @@ class Vocabulary(object): 但已经记录在词典中的词, 不会改变对应的 `int` """ - self.word2idx = {} + if self.word2idx is None: + self.word2idx = {} if self.padding is not None: self.word2idx[self.padding] = len(self.word2idx) if self.unknown is not None: @@ -316,6 +317,7 @@ class Vocabulary(object): """Use to prepare data for pickle. """ + len(self) # make sure vocab has been built state = self.__dict__.copy() # no need to pickle idx2word as it can be constructed from word2idx del state['idx2word']