|
|
@@ -11,7 +11,7 @@ class Batch(object): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataset, batch_size, sampler, use_cuda): |
|
|
|
def __init__(self, dataset, batch_size, sampler, use_cuda, sort_in_batch=False, sort_key=None): |
|
|
|
""" |
|
|
|
|
|
|
|
:param dataset: a DataSet object |
|
|
@@ -24,6 +24,8 @@ class Batch(object): |
|
|
|
self.batch_size = batch_size |
|
|
|
self.sampler = sampler |
|
|
|
self.use_cuda = use_cuda |
|
|
|
self.sort_in_batch = sort_in_batch |
|
|
|
self.sort_key = sort_key if sort_key is not None else 'word_seq' |
|
|
|
self.idx_list = None |
|
|
|
self.curidx = 0 |
|
|
|
|
|
|
@@ -49,13 +51,21 @@ class Batch(object): |
|
|
|
raise StopIteration |
|
|
|
else: |
|
|
|
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) |
|
|
|
padding_length = {field_name: max(field_length[self.curidx: endidx]) |
|
|
|
batch_idxes = self.idx_list[self.curidx: endidx] |
|
|
|
padding_length = {field_name: max([field_length[idx] for idx in batch_idxes]) |
|
|
|
for field_name, field_length in self.lengths.items()} |
|
|
|
batch_x, batch_y = defaultdict(list), defaultdict(list) |
|
|
|
|
|
|
|
# transform index to tensor and do padding for sequences |
|
|
|
for idx in range(self.curidx, endidx): |
|
|
|
batch = [] |
|
|
|
for idx in batch_idxes: |
|
|
|
x, y = self.dataset.to_tensor(idx, padding_length) |
|
|
|
batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y)) |
|
|
|
|
|
|
|
if self.sort_in_batch: |
|
|
|
batch = sorted(batch, key=lambda x: x[0], reverse=True) |
|
|
|
|
|
|
|
for _, x, y in batch: |
|
|
|
for name, tensor in x.items(): |
|
|
|
batch_x[name].append(tensor) |
|
|
|
for name, tensor in y.items(): |
|
|
|