Browse Source

fix & update batch

Add support for sorted batch output, which can be useful when using RNNs in PyTorch with `pack_padded_sequence` and `pad_packed_sequence`.
tags/v0.2.0
Yunfan Shao GitHub 6 years ago
parent
commit
f40dc2e6fa
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 13 additions and 3 deletions
  1. +13
    -3
      fastNLP/core/batch.py

+ 13
- 3
fastNLP/core/batch.py View File

@@ -11,7 +11,7 @@ class Batch(object):

"""

def __init__(self, dataset, batch_size, sampler, use_cuda):
def __init__(self, dataset, batch_size, sampler, use_cuda, sort_in_batch=False, sort_key=None):
"""

:param dataset: a DataSet object
@@ -24,6 +24,8 @@ class Batch(object):
self.batch_size = batch_size
self.sampler = sampler
self.use_cuda = use_cuda
self.sort_in_batch = sort_in_batch
self.sort_key = sort_key if sort_key is not None else 'word_seq'
self.idx_list = None
self.curidx = 0

@@ -49,13 +51,21 @@ class Batch(object):
raise StopIteration
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
padding_length = {field_name: max(field_length[self.curidx: endidx])
batch_idxes = self.idx_list[self.curidx: endidx]
padding_length = {field_name: max([field_length[idx] for idx in batch_idxes])
for field_name, field_length in self.lengths.items()}
batch_x, batch_y = defaultdict(list), defaultdict(list)

# transform index to tensor and do padding for sequences
for idx in range(self.curidx, endidx):
batch = []
for idx in batch_idxes:
x, y = self.dataset.to_tensor(idx, padding_length)
batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y))

if self.sort_in_batch:
batch = sorted(batch, key=lambda x: x[0], reverse=True)

for _, x, y in batch:
for name, tensor in x.items():
batch_x[name].append(tensor)
for name, tensor in y.items():


Loading…
Cancel
Save