diff --git a/examples/bdd_oia/bridge.py b/examples/bdd_oia/bridge.py index 0139277..2d0dcce 100644 --- a/examples/bdd_oia/bridge.py +++ b/examples/bdd_oia/bridge.py @@ -3,6 +3,7 @@ from typing import List, Any from ablkit.data import ListData from ablkit.bridge import SimpleBridge + class BDDBridge(SimpleBridge): def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] @@ -20,4 +21,4 @@ class BDDBridge(SimpleBridge): sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) abduced_idx.append(sub_list) data_examples.abduced_idx = abduced_idx - return data_examples.abduced_idx \ No newline at end of file + return data_examples.abduced_idx diff --git a/examples/bdd_oia/dataset/data_util.py b/examples/bdd_oia/dataset/data_util.py index 23bb9a8..a4a9ee9 100644 --- a/examples/bdd_oia/dataset/data_util.py +++ b/examples/bdd_oia/dataset/data_util.py @@ -14,5 +14,6 @@ def get_dataset(fname, get_pseudo_label=True): Y = [tuple(y) for y in Y] return X, pseudo_label, Y -if __name__ == '__main__': - dataset = get_dataset("val.npz") \ No newline at end of file + +if __name__ == "__main__": + dataset = get_dataset("val.npz") diff --git a/examples/bdd_oia/main.py b/examples/bdd_oia/main.py index 219f76e..13998e4 100644 --- a/examples/bdd_oia/main.py +++ b/examples/bdd_oia/main.py @@ -8,7 +8,7 @@ from ablkit.data.evaluation import SymbolAccuracy from ablkit.reasoning import Reasoner from ablkit.utils import ABLLogger, print_log -from models.nn import * +from models.nn import ConceptNet from models.bdd_nn import BDDNN from models.bdd_model import BDDABLModel from reasoning.bddkb import BDDKB @@ -19,12 +19,13 @@ from metric import BDDReasoningMetric def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): pred_prob = data_example.pred_prob.T # nc x 1 - pred_prob = np.concatenate([1-pred_prob, pred_prob], axis=1) # nc x 2 + pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 cols = np.arange(len(candidates_idxs[0]))[None, :] corr_prob = pred_prob[cols, candidates_idxs] - costs = - np.sum(np.log(corr_prob + 1e-6), axis=1) + costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) return costs + def get_args(): parser = argparse.ArgumentParser(description="BDD-OIA example") parser.add_argument( @@ -62,6 +63,7 @@ def get_args(): args = parser.parse_args() return args + def main(): args = get_args() @@ -116,7 +118,7 @@ def main(): kb, dist_func=multi_label_confidence_dist, max_revision=args.max_revision, - require_more_revision=args.require_more_revision + require_more_revision=args.require_more_revision, ) # -- Building Evaluation Metrics -------------------- diff --git a/examples/bdd_oia/metric.py b/examples/bdd_oia/metric.py index c73c2f9..e2042ec 100644 --- a/examples/bdd_oia/metric.py +++ b/examples/bdd_oia/metric.py @@ -3,6 +3,7 @@ from typing import Optional from ablkit.reasoning import KBBase from ablkit.data import BaseMetric, ListData + class BDDReasoningMetric(BaseMetric): def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: super().__init__(prefix) @@ -13,7 +14,9 @@ class BDDReasoningMetric(BaseMetric): y_list = data_examples.Y x_list = data_examples.X for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): - pred_y = self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()) + pred_y = self.kb.logic_forward( + pred_pseudo_label, *(x,) if self.kb._num_args == 2 else () + ) for py, yy in zip(pred_y, y): self.results.append(int(py == yy)) @@ -21,4 +24,4 @@ class BDDReasoningMetric(BaseMetric): results = self.results metrics = dict() metrics["reasoning_accuracy"] = sum(results) / len(results) - return metrics \ No newline at end of file + return metrics diff --git a/examples/bdd_oia/models/bdd_model.py b/examples/bdd_oia/models/bdd_model.py index bdddb7e..05b1824 100644 --- a/examples/bdd_oia/models/bdd_model.py +++ b/examples/bdd_oia/models/bdd_model.py @@ -5,6 +5,7 @@ from ablkit.data import ListData from ablkit.learning import ABLModel from ablkit.utils import reform_list + class BDDABLModel(ABLModel): def predict(self, data_examples: ListData) -> Dict: model = self.base_model @@ -21,4 +22,4 @@ class BDDABLModel(ABLModel): data_examples.pred_idx = label data_examples.pred_prob = prob - return {"label": label, "prob": prob} \ No newline at end of file + return {"label": label, "prob": prob} diff --git a/examples/bdd_oia/models/bdd_nn.py b/examples/bdd_oia/models/bdd_nn.py index 5fdd378..89c7517 100644 --- a/examples/bdd_oia/models/bdd_nn.py +++ b/examples/bdd_oia/models/bdd_nn.py @@ -1,10 +1,9 @@ import logging -import os -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional import numpy import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset from ablkit.utils.logger import print_log @@ -15,11 +14,11 @@ class MultiLabelClassificationDataset(ClassificationDataset): if (not isinstance(X, list)) or (not isinstance(Y, list)): raise ValueError("X and Y should be of type list.") self.X = X - self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss + self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss self.transform = transform -class BDDNN(BasicNN): +class BDDNN(BasicNN): def predict( self, data_loader: Optional[DataLoader] = None, diff --git a/examples/bdd_oia/models/nn.py b/examples/bdd_oia/models/nn.py index 216be0e..8ff8173 100644 --- a/examples/bdd_oia/models/nn.py +++ b/examples/bdd_oia/models/nn.py @@ -1,5 +1,6 @@ from torch import nn + class SimpleNet(nn.Module): def __init__(self, num_features=2048, num_concepts=21): super(SimpleNet, self).__init__() @@ -8,6 +9,7 @@ class SimpleNet(nn.Module): def forward(self, x): return self.fc(x) + class ConceptNet(nn.Module): def __init__(self, num_features=2048, num_concepts=21): super(ConceptNet, self).__init__() @@ -15,7 +17,7 @@ class ConceptNet(nn.Module): self.fc = nn.Sequential( nn.Linear(num_features, intermidate_dim), nn.SiLU(), - nn.Linear(intermidate_dim, num_concepts) + nn.Linear(intermidate_dim, num_concepts), ) def forward(self, x): diff --git a/examples/bdd_oia/reasoning/bddkb.py b/examples/bdd_oia/reasoning/bddkb.py index 3548c69..2ef9424 100644 --- a/examples/bdd_oia/reasoning/bddkb.py +++ b/examples/bdd_oia/reasoning/bddkb.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from ablkit.reasoning import KBBase + class BDDKB(KBBase): def __init__(self, pseudo_label_list=None): if pseudo_label_list is None: @@ -20,9 +21,29 @@ class BDDKB(KBBase): (1, 0, 1, 1) 196 """ assert len(attrs) == 21 - green_light, follow, road_clear, red_light, traffic_sign, car, person, rider, other_obstacle, \ - left_lane, left_green_light, left_follow, no_left_lane, left_obstacle, left_solid_line, \ - right_lane, right_green_light, right_follow, no_right_lane, right_obstacle, right_solid_line = attrs + ( + green_light, + follow, + road_clear, + red_light, + traffic_sign, + car, + person, + rider, + other_obstacle, + left_lane, + left_green_light, + left_follow, + no_left_lane, + left_obstacle, + left_solid_line, + right_lane, + right_green_light, + right_follow, + no_right_lane, + right_obstacle, + right_solid_line, + ) = attrs illegal_return = (0, 0, 0, 0) if red_light == green_light == 1: @@ -30,8 +51,8 @@ class BDDKB(KBBase): obstacle = car or person or rider or other_obstacle if road_clear == obstacle: return illegal_return - move_forward = (green_light or follow or road_clear) - stop = (red_light or traffic_sign or obstacle) + move_forward = green_light or follow or road_clear + stop = red_light or traffic_sign or obstacle if stop: move_forward = 0 @@ -43,4 +64,4 @@ class BDDKB(KBBase): cannot_turn_right = no_right_lane or right_obstacle or right_solid_line turn_right = can_turn_right and int(not cannot_turn_right) - return move_forward, stop, turn_left, turn_right \ No newline at end of file + return move_forward, stop, turn_left, turn_right diff --git a/examples/bdd_oia/requirements.txt b/examples/bdd_oia/requirements.txt index c238889..a1d6490 100644 --- a/examples/bdd_oia/requirements.txt +++ b/examples/bdd_oia/requirements.txt @@ -1,2 +1 @@ -torch ablkit diff --git a/examples/hwf/main.py b/examples/hwf/main.py index 6ea0380..fb03be6 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -81,7 +81,7 @@ def main(): "--label-smoothing", type=float, default=0.2, - help="label smoothing in cross entropy loss (default : 0.2)" + help="label smoothing in cross entropy loss (default : 0.2)", ) parser.add_argument( "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" @@ -138,7 +138,12 @@ def main(): # Build BasicNN base_model = BasicNN( - cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs, + cls, + loss_fn, + optimizer, + device=device, + batch_size=args.batch_size, + num_epochs=args.epochs, ) # Build ABLModel