|
@@ -51,34 +51,20 @@ class Batch(object): |
|
|
raise StopIteration |
|
|
raise StopIteration |
|
|
else: |
|
|
else: |
|
|
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) |
|
|
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) |
|
|
batch_idxes = self.idx_list[self.curidx: endidx] |
|
|
|
|
|
padding_length = {field_name: max([field_length[idx] for idx in batch_idxes]) |
|
|
|
|
|
for field_name, field_length in self.lengths.items()} |
|
|
|
|
|
batch_x, batch_y = defaultdict(list), defaultdict(list) |
|
|
|
|
|
|
|
|
|
|
|
# transform index to tensor and do padding for sequences |
|
|
|
|
|
batch = [] |
|
|
|
|
|
for idx in batch_idxes: |
|
|
|
|
|
x, y = self.dataset.to_tensor(idx, padding_length) |
|
|
|
|
|
batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y)) |
|
|
|
|
|
|
|
|
|
|
|
if self.sort_in_batch: |
|
|
|
|
|
batch = sorted(batch, key=lambda x: x[0], reverse=True) |
|
|
|
|
|
|
|
|
|
|
|
for _, x, y in batch: |
|
|
|
|
|
for name, tensor in x.items(): |
|
|
|
|
|
batch_x[name].append(tensor) |
|
|
|
|
|
for name, tensor in y.items(): |
|
|
|
|
|
batch_y[name].append(tensor) |
|
|
|
|
|
|
|
|
|
|
|
# combine instances to form a batch |
|
|
|
|
|
for batch in (batch_x, batch_y): |
|
|
|
|
|
for name, tensor_list in batch.items(): |
|
|
|
|
|
if self.use_cuda: |
|
|
|
|
|
batch[name] = torch.stack(tensor_list, dim=0).cuda() |
|
|
|
|
|
else: |
|
|
|
|
|
batch[name] = torch.stack(tensor_list, dim=0) |
|
|
|
|
|
|
|
|
batch_x, batch_y = {}, {} |
|
|
|
|
|
|
|
|
|
|
|
indices = self.idx_list[self.curidx:endidx] |
|
|
|
|
|
|
|
|
|
|
|
for field_name, field in self.dataset.get_fields(): |
|
|
|
|
|
batch = field.get(indices) |
|
|
|
|
|
if not field.tensorable: #TODO 修改 |
|
|
|
|
|
pass |
|
|
|
|
|
elif field.is_target: |
|
|
|
|
|
batch_y[field_name] = batch |
|
|
|
|
|
else: |
|
|
|
|
|
batch_x[field_name] = batch |
|
|
|
|
|
|
|
|
self.curidx = endidx |
|
|
self.curidx = endidx |
|
|
|
|
|
|
|
|
return batch_x, batch_y |
|
|
return batch_x, batch_y |
|
|
|
|
|
|