Browse Source

[ENH] Remove the `list` type and length-consistency constraints in ListData

pull/4/head
Gao Enhao 1 year ago
parent
commit
5887cbdb97
4 changed files with 30 additions and 47 deletions
  1. +1
    -1
      abl/bridge/simple_bridge.py
  2. +1
    -2
      abl/learning/abl_model.py
  3. +3
    -3
      abl/reasoning/reasoner.py
  4. +25
    -41
      abl/structures/list_data.py

+ 1
- 1
abl/bridge/simple_bridge.py View File

@@ -25,7 +25,7 @@ class SimpleBridge(BaseBridge):


def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
self.model.predict(data_samples) self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.get("pred_prob", None)
return data_samples.pred_idx, data_samples.pred_prob


def abduce_pseudo_label( def abduce_pseudo_label(
self, self,


+ 1
- 2
abl/learning/abl_model.py View File

@@ -76,8 +76,7 @@ class ABLModel:
label = reform_list(label, data_samples.X) label = reform_list(label, data_samples.X)


data_samples.pred_idx = label data_samples.pred_idx = label
if prob is not None:
data_samples.pred_prob = prob
data_samples.pred_prob = prob


return {"label": label, "prob": prob} return {"label": label, "prob": prob}




+ 3
- 3
abl/reasoning/reasoner.py View File

@@ -222,9 +222,9 @@ class ReasonerBase:
symbol_num = data_sample.elements_num("pred_pseudo_label") symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num) max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
pred_pseudo_label = data_sample.pred_pseudo_label[0]
pred_prob = data_sample.pred_prob[0]
y = data_sample.Y[0]
pred_pseudo_label = data_sample.pred_pseudo_label
pred_prob = data_sample.pred_prob
y = data_sample.Y
if self.use_zoopt: if self.use_zoopt:
solution = self.zoopt_get_solution( solution = self.zoopt_get_solution(


+ 25
- 41
abl/structures/list_data.py View File

@@ -132,22 +132,21 @@ class ListData(BaseDataElement):
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
raise AttributeError( raise AttributeError(
f"{name} has been used as a "
"private attribute, which is immutable."
f"{name} has been used as a " "private attribute, which is immutable."
) )


else: else:
assert isinstance(value, list), "value must be of type `list`"
# assert isinstance(value, list), "value must be of type `list`"


if len(self) > 0:
assert len(value) == len(self), (
"The length of "
f"values {len(value)} is "
"not consistent with "
"the length of this "
":obj:`ListData` "
f"{len(self)}"
)
# if len(self) > 0:
# assert len(value) == len(self), (
# "The length of "
# f"values {len(value)} is "
# "not consistent with "
# "the length of this "
# ":obj:`ListData` "
# f"{len(self)}"
# )
super().__setattr__(name, value) super().__setattr__(name, value)


__setitem__ = __setattr__ __setitem__ = __setattr__
@@ -176,32 +175,15 @@ class ListData(BaseDataElement):
if isinstance(item, str): if isinstance(item, str):
return getattr(self, item) return getattr(self, item)


if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f"Index {item} out of range!")
else:
# keep the dimension
item = slice(item, None, len(self))

new_data = self.__class__(metainfo=self.metainfo) new_data = self.__class__(metainfo=self.metainfo)

if isinstance(item, torch.Tensor): if isinstance(item, torch.Tensor):
assert item.dim() == 1, (
"Only support to get the" " values along the first dimension."
)
if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), (
"The shape of the "
"input(BoolTensor) "
f"{len(item)} "
"does not match the shape "
"of the indexed tensor "
"in results_field "
f"{len(self)} at "
"first dimension."
)
assert item.dim() == 1, "Only support to get the" " values along the first dimension."


for k, v in self.items(): for k, v in self.items():
if isinstance(v, torch.Tensor):
if v is None:
new_data[k] = None
elif isinstance(v, torch.Tensor):
new_data[k] = v[item] new_data[k] = v[item]
elif isinstance(v, np.ndarray): elif isinstance(v, np.ndarray):
new_data[k] = v[item.cpu().numpy()] new_data[k] = v[item.cpu().numpy()]
@@ -235,9 +217,12 @@ class ListData(BaseDataElement):
) )


else: else:
# item is a slice
# item is a slice or int
for k, v in self.items(): for k, v in self.items():
new_data[k] = v[item]
if v is None:
new_data[k] = None
else:
new_data[k] = v[item]
return new_data # type:ignore return new_data # type:ignore


@staticmethod @staticmethod
@@ -289,8 +274,7 @@ class ListData(BaseDataElement):
new_values = v0.cat(values) new_values = v0.cat(values)
else: else:
raise ValueError( raise ValueError(
f"The type of `{k}` is `{type(v0)}` which has no "
"attribute of `cat`"
f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`"
) )
new_data[k] = new_values new_data[k] = new_values
return new_data # type:ignore return new_data # type:ignore
@@ -302,15 +286,15 @@ class ListData(BaseDataElement):
list: Flattened data fields. list: Flattened data fields.
""" """
return flatten_list(self[item]) return flatten_list(self[item])
def elements_num(self, item: IndexType) -> int: def elements_num(self, item: IndexType) -> int:
"""int: The number of elements in self[item].""" """int: The number of elements in self[item]."""
return len(self.flatten(item)) return len(self.flatten(item))
def to_tuple(self, item: IndexType) -> tuple: def to_tuple(self, item: IndexType) -> tuple:
"""tuple: The data fields in self[item] converted to tuple.""" """tuple: The data fields in self[item] converted to tuple."""
return to_hashable(self[item]) return to_hashable(self[item])
def __len__(self) -> int: def __len__(self) -> int:
"""int: The length of ListData.""" """int: The length of ListData."""
if len(self._data_fields) > 0: if len(self._data_fields) > 0:


Loading…
Cancel
Save