Browse Source

- fix batch & vocab

tags/v0.4.10
yunfan 6 years ago
parent
commit
443184f82e
2 changed files with 56 additions and 46 deletions
  1. +53
    -45
      fastNLP/core/batch.py
  2. +3
    -1
      fastNLP/core/vocabulary.py

+ 53
- 45
fastNLP/core/batch.py View File

@@ -9,6 +9,7 @@ import atexit

from .sampler import RandomSampler, Sampler
import torch.multiprocessing as mp
from queue import Empty, Full

_python_is_exit = False

@@ -92,7 +93,7 @@ class Batch(object):
:return:
"""
if self.prefetch:
return _run_batch_iter(self)
return self._run_batch_iter(self)
def batch_iter():
self.init_iter()
@@ -120,6 +121,57 @@ class Batch(object):
"""
return self.cur_batch_indices

@staticmethod
def _run_fetch(batch, q):
try:
global _python_is_exit
batch.init_iter()
# print('start fetch')
while 1:
res = batch.fetch_one()
# print('fetch one')
while 1:
try:
q.put(res, timeout=3)
break
except Full:
if _python_is_exit:
return
if res is None:
# print('fetch done, waiting processing')
break
# print('fetch exit')
except Exception as e:
q.put(e)
finally:
q.join()

@staticmethod
def _run_batch_iter(batch):
q = mp.JoinableQueue(maxsize=10)
fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q))
fetch_p.daemon = True
fetch_p.start()
# print('fork fetch process')
while 1:
try:
res = q.get(timeout=1)
q.task_done()
# print('get fetched')
if res is None:
break
elif isinstance(res, Exception):
raise res
yield res
except Empty as e:
if fetch_p.is_alive():
continue
else:
break
fetch_p.terminate()
fetch_p.join()
# print('iter done')


def _to_tensor(batch, dtype):
try:
@@ -131,47 +183,3 @@ def _to_tensor(batch, dtype):
pass
return batch


def _run_fetch(batch, q):
global _python_is_exit
batch.init_iter()
# print('start fetch')
while 1:
res = batch.fetch_one()
# print('fetch one')
while 1:
try:
q.put(res, timeout=3)
break
except:
if _python_is_exit:
return
if res is None:
# print('fetch done, waiting processing')
q.join()
break
# print('fetch exit')


def _run_batch_iter(batch):
q = mp.JoinableQueue(maxsize=10)
fetch_p = mp.Process(target=_run_fetch, args=(batch, q))
fetch_p.daemon = True
fetch_p.start()
# print('fork fetch process')
while 1:
try:
res = q.get(timeout=1)
q.task_done()
# print('get fetched')
if res is None:
break
yield res
except Exception as e:
if fetch_p.is_alive():
continue
else:
break
fetch_p.terminate()
fetch_p.join()
# print('iter done')

+ 3
- 1
fastNLP/core/vocabulary.py View File

@@ -110,7 +110,8 @@ class Vocabulary(object):
但已经记录在词典中的词, 不会改变对应的 `int`

"""
self.word2idx = {}
if self.word2idx is None:
self.word2idx = {}
if self.padding is not None:
self.word2idx[self.padding] = len(self.word2idx)
if self.unknown is not None:
@@ -316,6 +317,7 @@ class Vocabulary(object):
"""Use to prepare data for pickle.

"""
len(self) # make sure vocab has been built
state = self.__dict__.copy()
# no need to pickle idx2word as it can be constructed from word2idx
del state['idx2word']


Loading…
Cancel
Save