|
|
@@ -30,21 +30,25 @@ class DataSet(object): |
|
|
|
def __init__(self, dataset, idx=-1): |
|
|
|
self.dataset = dataset |
|
|
|
self.idx = idx |
|
|
|
self.fields = None |
|
|
|
|
|
|
|
def __next__(self): |
|
|
|
self.idx += 1 |
|
|
|
if self.idx >= len(self.dataset): |
|
|
|
try: |
|
|
|
self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()} |
|
|
|
except IndexError: |
|
|
|
raise StopIteration |
|
|
|
return self |
|
|
|
|
|
|
|
def __getitem__(self, name): |
|
|
|
return self.dataset[name][self.idx] |
|
|
|
return self.fields[name] |
|
|
|
|
|
|
|
def __setitem__(self, name, val): |
|
|
|
if name not in self.dataset: |
|
|
|
new_fields = [None] * len(self.dataset) |
|
|
|
self.dataset.add_field(name, new_fields) |
|
|
|
self.dataset[name][self.idx] = val |
|
|
|
self.fields[name] = val |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name |
|
|
@@ -163,9 +167,8 @@ class DataSet(object): |
|
|
|
self.field_arrays[new_name] = self.field_arrays.pop(old_name) |
|
|
|
else: |
|
|
|
raise KeyError("{} is not a valid name. ".format(old_name)) |
|
|
|
return self |
|
|
|
|
|
|
|
def set_is_target(self, **fields): |
|
|
|
def set_target(self, **fields): |
|
|
|
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. |
|
|
|
|
|
|
|
:param key-value pairs for field-name and `is_target` value(True, False). |
|
|
@@ -176,9 +179,20 @@ class DataSet(object): |
|
|
|
self.field_arrays[name].is_target = val |
|
|
|
else: |
|
|
|
raise KeyError("{} is not a valid field name.".format(name)) |
|
|
|
self._set_need_tensor(**fields) |
|
|
|
return self |
|
|
|
|
|
|
|
def set_input(self, **fields): |
|
|
|
for name, val in fields.items(): |
|
|
|
if name in self.field_arrays: |
|
|
|
assert isinstance(val, bool) |
|
|
|
self.field_arrays[name].is_target = not val |
|
|
|
else: |
|
|
|
raise KeyError("{} is not a valid field name.".format(name)) |
|
|
|
self._set_need_tensor(**fields) |
|
|
|
return self |
|
|
|
|
|
|
|
def set_need_tensor(self, **kwargs): |
|
|
|
def _set_need_tensor(self, **kwargs): |
|
|
|
for name, val in kwargs.items(): |
|
|
|
if name in self.field_arrays: |
|
|
|
assert isinstance(val, bool) |
|
|
|