|
|
@@ -365,6 +365,15 @@ class DataSet(object): |
|
|
|
if idx not in self: |
|
|
|
raise KeyError("No such field called {} in DataSet.".format(idx)) |
|
|
|
return self.field_arrays[idx] |
|
|
|
elif isinstance(idx, list): |
|
|
|
dataset = DataSet() |
|
|
|
for i in idx: |
|
|
|
assert isinstance(i, int), "Only int index allowed." |
|
|
|
instance = self[i] |
|
|
|
dataset.append(instance) |
|
|
|
for field_name, field in self.field_arrays.items(): |
|
|
|
dataset.field_arrays[field_name].to(field) |
|
|
|
return dataset |
|
|
|
else: |
|
|
|
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) |
|
|
|
|
|
|
|