|
|
@@ -88,10 +88,11 @@ class DataSet(object): |
|
|
|
assert name in self.field_arrays |
|
|
|
self.field_arrays[name].append(field) |
|
|
|
|
|
|
|
def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
    """Wrap ``fields`` in a ``FieldArray`` and store it under ``name``.

    If the dataset already holds at least one field, the new field must
    have the same length as the dataset; this is checked with ``assert``.
    Assigning to a name that is already present replaces the stored entry.

    :param name: key under which the field is stored in ``self.field_arrays``.
    :param fields: the field contents; must support ``len()``.
    :param padding_val: forwarded unchanged to ``FieldArray``.
    :param need_tensor: forwarded unchanged to ``FieldArray``.
    :param is_target: forwarded unchanged to ``FieldArray``.
    """
    # The length check only makes sense once there is something to compare to.
    has_existing_fields = len(self.field_arrays) != 0
    if has_existing_fields:
        assert len(self) == len(fields)

    new_array = FieldArray(
        name,
        fields,
        padding_val=padding_val,
        need_tensor=need_tensor,
        is_target=is_target,
    )
    self.field_arrays[name] = new_array
|
|
|
|
|
|
@@ -104,6 +105,16 @@ class DataSet(object): |
|
|
|
def __getitem__(self, name): |
|
|
|
if isinstance(name, int): |
|
|
|
return self.Instance(self, idx=name) |
|
|
|
elif isinstance(name, slice): |
|
|
|
ds = DataSet() |
|
|
|
for field in self.field_arrays.values(): |
|
|
|
ds.add_field(name=field.name, |
|
|
|
fields=field.content[name], |
|
|
|
padding_val=field.padding_val, |
|
|
|
need_tensor=field.need_tensor, |
|
|
|
is_target=field.is_target) |
|
|
|
return ds |
|
|
|
|
|
|
|
elif isinstance(name, str): |
|
|
|
return self.field_arrays[name] |
|
|
|
else: |
|
|
@@ -187,7 +198,15 @@ class DataSet(object): |
|
|
|
for ins in self: |
|
|
|
results.append(func(ins)) |
|
|
|
if new_field_name is not None: |
|
|
|
self.add_field(new_field_name, results) |
|
|
|
if new_field_name in self.field_arrays: |
|
|
|
# overwrite the field, keep same attributes |
|
|
|
old_field = self.field_arrays[new_field_name] |
|
|
|
padding_val = old_field.padding_val |
|
|
|
need_tensor = old_field.need_tensor |
|
|
|
is_target = old_field.is_target |
|
|
|
self.add_field(new_field_name, results, padding_val, need_tensor, is_target) |
|
|
|
else: |
|
|
|
self.add_field(new_field_name, results) |
|
|
|
else: |
|
|
|
return results |
|
|
|
|
|
|
|