From f40dc2e6fa2639e8d9ab6a491c9042c86da81464 Mon Sep 17 00:00:00 2001 From: Yunfan Shao Date: Sun, 4 Nov 2018 16:29:51 +0800 Subject: [PATCH] fix & update batch Add support for sorted batch output, can be useful when using RNN in Pytorch with `pack_padded_sequence` & `pad_packed_sequence` --- fastNLP/core/batch.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index bf837d0f..b55ae3dd 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -11,7 +11,7 @@ class Batch(object): """ - def __init__(self, dataset, batch_size, sampler, use_cuda): + def __init__(self, dataset, batch_size, sampler, use_cuda, sort_in_batch=False, sort_key=None): """ :param dataset: a DataSet object @@ -24,6 +24,8 @@ class Batch(object): self.batch_size = batch_size self.sampler = sampler self.use_cuda = use_cuda + self.sort_in_batch = sort_in_batch + self.sort_key = sort_key if sort_key is not None else 'word_seq' self.idx_list = None self.curidx = 0 @@ -49,13 +51,21 @@ class Batch(object): raise StopIteration else: endidx = min(self.curidx + self.batch_size, len(self.idx_list)) - padding_length = {field_name: max(field_length[self.curidx: endidx]) + batch_idxes = self.idx_list[self.curidx: endidx] + padding_length = {field_name: max([field_length[idx] for idx in batch_idxes]) for field_name, field_length in self.lengths.items()} batch_x, batch_y = defaultdict(list), defaultdict(list) # transform index to tensor and do padding for sequences - for idx in range(self.curidx, endidx): + batch = [] + for idx in batch_idxes: x, y = self.dataset.to_tensor(idx, padding_length) + batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y)) + + if self.sort_in_batch: + batch = sorted(batch, key=lambda x: x[0], reverse=True) + + for _, x, y in batch: for name, tensor in x.items(): batch_x[name].append(tensor) for name, tensor in y.items():