diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index bf837d0f..b55ae3dd 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -11,7 +11,7 @@ class Batch(object):
 
     """
 
-    def __init__(self, dataset, batch_size, sampler, use_cuda):
+    def __init__(self, dataset, batch_size, sampler, use_cuda, sort_in_batch=False, sort_key=None):
         """
 
         :param dataset: a DataSet object
@@ -24,6 +24,8 @@ class Batch(object):
         self.batch_size = batch_size
         self.sampler = sampler
         self.use_cuda = use_cuda
+        self.sort_in_batch = sort_in_batch
+        self.sort_key = sort_key if sort_key is not None else 'word_seq'
         self.idx_list = None
         self.curidx = 0
 
@@ -49,13 +51,21 @@ class Batch(object):
             raise StopIteration
         else:
             endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            padding_length = {field_name: max(field_length[self.curidx: endidx])
+            batch_idxes = self.idx_list[self.curidx: endidx]
+            padding_length = {field_name: max([field_length[idx] for idx in batch_idxes])
                               for field_name, field_length in self.lengths.items()}
             batch_x, batch_y = defaultdict(list), defaultdict(list)
 
             # transform index to tensor and do padding for sequences
-            for idx in range(self.curidx, endidx):
+            batch = []
+            for idx in batch_idxes:
                 x, y = self.dataset.to_tensor(idx, padding_length)
+                batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y))
+
+            if self.sort_in_batch:
+                batch = sorted(batch, key=lambda x: x[0], reverse=True)
+
+            for _, x, y in batch:
                 for name, tensor in x.items():
                     batch_x[name].append(tensor)
                 for name, tensor in y.items():
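
For context, a minimal usage sketch of the new flag. The DataSet `train_set` and the use of fastNLP's RandomSampler are assumptions for illustration; only `sort_in_batch` and `sort_key` come from this diff. Sorting each mini-batch by descending sequence length matches the layout that `torch.nn.utils.rnn.pack_padded_sequence` expects, which is presumably the motivation for `reverse=True` in the sort:

    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import RandomSampler

    # `train_set` is assumed to be a fastNLP DataSet whose 'word_seq'
    # field lengths are tracked in Batch.lengths.
    data_iterator = Batch(dataset=train_set, batch_size=16,
                          sampler=RandomSampler(), use_cuda=False,
                          sort_in_batch=True, sort_key='word_seq')

    for batch_x, batch_y in data_iterator:
        # Within each mini-batch, examples now arrive in descending
        # 'word_seq' length order, ready for pack_padded_sequence.
        pass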