|
|
@@ -56,8 +56,8 @@ class Batch(object): |
|
|
|
indices = self.idx_list[self.curidx:endidx] |
|
|
|
|
|
|
|
for field_name, field in self.dataset.get_fields(): |
|
|
|
batch = field.get(indices) |
|
|
|
if not field.tensorable: #TODO 修改 |
|
|
|
batch = torch.from_numpy(field.get(indices)) |
|
|
|
if not field.need_tensor: #TODO 修改 |
|
|
|
pass |
|
|
|
elif field.is_target: |
|
|
|
batch_y[field_name] = batch |
|
|
|