diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index ac39df1..76a492d 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -25,7 +25,7 @@ class SimpleBridge(BaseBridge): def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: self.model.predict(data_samples) - return data_samples.pred_idx, data_samples.get("pred_prob", None) + return data_samples.pred_idx, data_samples.pred_prob def abduce_pseudo_label( self, diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 97775c0..bcf03df 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -76,8 +76,7 @@ class ABLModel: label = reform_list(label, data_samples.X) data_samples.pred_idx = label - if prob is not None: - data_samples.pred_prob = prob + data_samples.pred_prob = prob return {"label": label, "prob": prob} diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 2e57570..962c087 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -222,9 +222,9 @@ class ReasonerBase: symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - pred_pseudo_label = data_sample.pred_pseudo_label[0] - pred_prob = data_sample.pred_prob[0] - y = data_sample.Y[0] + pred_pseudo_label = data_sample.pred_pseudo_label + pred_prob = data_sample.pred_prob + y = data_sample.Y if self.use_zoopt: solution = self.zoopt_get_solution( diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py index 2571a13..a53ffc5 100644 --- a/abl/structures/list_data.py +++ b/abl/structures/list_data.py @@ -132,22 +132,21 @@ class ListData(BaseDataElement): super().__setattr__(name, value) else: raise AttributeError( - f"{name} has been used as a " - "private attribute, which is immutable." + f"{name} has been used as a " "private attribute, which is immutable." ) else: - assert isinstance(value, list), "value must be of type `list`" + # assert isinstance(value, list), "value must be of type `list`" - if len(self) > 0: - assert len(value) == len(self), ( - "The length of " - f"values {len(value)} is " - "not consistent with " - "the length of this " - ":obj:`ListData` " - f"{len(self)}" - ) + # if len(self) > 0: + # assert len(value) == len(self), ( + # "The length of " + # f"values {len(value)} is " + # "not consistent with " + # "the length of this " + # ":obj:`ListData` " + # f"{len(self)}" + # ) super().__setattr__(name, value) __setitem__ = __setattr__ @@ -176,32 +175,15 @@ class ListData(BaseDataElement): if isinstance(item, str): return getattr(self, item) - if isinstance(item, int): - if item >= len(self) or item < -len(self): # type:ignore - raise IndexError(f"Index {item} out of range!") - else: - # keep the dimension - item = slice(item, None, len(self)) - new_data = self.__class__(metainfo=self.metainfo) + if isinstance(item, torch.Tensor): - assert item.dim() == 1, ( - "Only support to get the" " values along the first dimension." - ) - if isinstance(item, BoolTypeTensor.__args__): - assert len(item) == len(self), ( - "The shape of the " - "input(BoolTensor) " - f"{len(item)} " - "does not match the shape " - "of the indexed tensor " - "in results_field " - f"{len(self)} at " - "first dimension." - ) + assert item.dim() == 1, "Only support to get the" " values along the first dimension." for k, v in self.items(): - if isinstance(v, torch.Tensor): + if v is None: + new_data[k] = None + elif isinstance(v, torch.Tensor): new_data[k] = v[item] elif isinstance(v, np.ndarray): new_data[k] = v[item.cpu().numpy()] @@ -235,9 +217,12 @@ class ListData(BaseDataElement): ) else: - # item is a slice + # item is a slice or int for k, v in self.items(): - new_data[k] = v[item] + if v is None: + new_data[k] = None + else: + new_data[k] = v[item] return new_data # type:ignore @staticmethod @@ -289,8 +274,7 @@ class ListData(BaseDataElement): new_values = v0.cat(values) else: raise ValueError( - f"The type of `{k}` is `{type(v0)}` which has no " - "attribute of `cat`" + f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`" ) new_data[k] = new_values return new_data # type:ignore @@ -302,15 +286,15 @@ class ListData(BaseDataElement): list: Flattened data fields. """ return flatten_list(self[item]) - + def elements_num(self, item: IndexType) -> int: """int: The number of elements in self[item].""" return len(self.flatten(item)) - + def to_tuple(self, item: IndexType) -> tuple: """tuple: The data fields in self[item] converted to tuple.""" return to_hashable(self[item]) - + def __len__(self) -> int: """int: The length of ListData.""" if len(self._data_fields) > 0: