@@ -3,6 +3,7 @@ from typing import List, Any | |||
from ablkit.data import ListData | |||
from ablkit.bridge import SimpleBridge | |||
class BDDBridge(SimpleBridge): | |||
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||
pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] | |||
@@ -20,4 +21,4 @@ class BDDBridge(SimpleBridge): | |||
sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) | |||
abduced_idx.append(sub_list) | |||
data_examples.abduced_idx = abduced_idx | |||
return data_examples.abduced_idx | |||
return data_examples.abduced_idx |
@@ -14,5 +14,6 @@ def get_dataset(fname, get_pseudo_label=True): | |||
Y = [tuple(y) for y in Y] | |||
return X, pseudo_label, Y | |||
if __name__ == '__main__': | |||
dataset = get_dataset("val.npz") | |||
if __name__ == "__main__": | |||
dataset = get_dataset("val.npz") |
@@ -8,7 +8,7 @@ from ablkit.data.evaluation import SymbolAccuracy | |||
from ablkit.reasoning import Reasoner | |||
from ablkit.utils import ABLLogger, print_log | |||
from models.nn import * | |||
from models.nn import ConceptNet | |||
from models.bdd_nn import BDDNN | |||
from models.bdd_model import BDDABLModel | |||
from reasoning.bddkb import BDDKB | |||
@@ -19,12 +19,13 @@ from metric import BDDReasoningMetric | |||
def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): | |||
pred_prob = data_example.pred_prob.T # nc x 1 | |||
pred_prob = np.concatenate([1-pred_prob, pred_prob], axis=1) # nc x 2 | |||
pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 | |||
cols = np.arange(len(candidates_idxs[0]))[None, :] | |||
corr_prob = pred_prob[cols, candidates_idxs] | |||
costs = - np.sum(np.log(corr_prob + 1e-6), axis=1) | |||
costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) | |||
return costs | |||
def get_args(): | |||
parser = argparse.ArgumentParser(description="BDD-OIA example") | |||
parser.add_argument( | |||
@@ -62,6 +63,7 @@ def get_args(): | |||
args = parser.parse_args() | |||
return args | |||
def main(): | |||
args = get_args() | |||
@@ -116,7 +118,7 @@ def main(): | |||
kb, | |||
dist_func=multi_label_confidence_dist, | |||
max_revision=args.max_revision, | |||
require_more_revision=args.require_more_revision | |||
require_more_revision=args.require_more_revision, | |||
) | |||
# -- Building Evaluation Metrics -------------------- | |||
@@ -3,6 +3,7 @@ from typing import Optional | |||
from ablkit.reasoning import KBBase | |||
from ablkit.data import BaseMetric, ListData | |||
class BDDReasoningMetric(BaseMetric): | |||
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | |||
super().__init__(prefix) | |||
@@ -13,7 +14,9 @@ class BDDReasoningMetric(BaseMetric): | |||
y_list = data_examples.Y | |||
x_list = data_examples.X | |||
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): | |||
pred_y = self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()) | |||
pred_y = self.kb.logic_forward( | |||
pred_pseudo_label, *(x,) if self.kb._num_args == 2 else () | |||
) | |||
for py, yy in zip(pred_y, y): | |||
self.results.append(int(py == yy)) | |||
@@ -21,4 +24,4 @@ class BDDReasoningMetric(BaseMetric): | |||
results = self.results | |||
metrics = dict() | |||
metrics["reasoning_accuracy"] = sum(results) / len(results) | |||
return metrics | |||
return metrics |
@@ -5,6 +5,7 @@ from ablkit.data import ListData | |||
from ablkit.learning import ABLModel | |||
from ablkit.utils import reform_list | |||
class BDDABLModel(ABLModel): | |||
def predict(self, data_examples: ListData) -> Dict: | |||
model = self.base_model | |||
@@ -21,4 +22,4 @@ class BDDABLModel(ABLModel): | |||
data_examples.pred_idx = label | |||
data_examples.pred_prob = prob | |||
return {"label": label, "prob": prob} | |||
return {"label": label, "prob": prob} |
@@ -1,10 +1,9 @@ | |||
import logging | |||
import os | |||
from typing import Any, Callable, List, Optional, Tuple, Union | |||
from typing import Any, Callable, List, Optional | |||
import numpy | |||
import torch | |||
from torch.utils.data import DataLoader, Dataset | |||
from torch.utils.data import DataLoader | |||
from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset | |||
from ablkit.utils.logger import print_log | |||
@@ -15,11 +14,11 @@ class MultiLabelClassificationDataset(ClassificationDataset): | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
self.X = X | |||
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss | |||
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss | |||
self.transform = transform | |||
class BDDNN(BasicNN): | |||
class BDDNN(BasicNN): | |||
def predict( | |||
self, | |||
data_loader: Optional[DataLoader] = None, | |||
@@ -1,5 +1,6 @@ | |||
from torch import nn | |||
class SimpleNet(nn.Module): | |||
def __init__(self, num_features=2048, num_concepts=21): | |||
super(SimpleNet, self).__init__() | |||
@@ -8,6 +9,7 @@ class SimpleNet(nn.Module): | |||
def forward(self, x): | |||
return self.fc(x) | |||
class ConceptNet(nn.Module): | |||
def __init__(self, num_features=2048, num_concepts=21): | |||
super(ConceptNet, self).__init__() | |||
@@ -15,7 +17,7 @@ class ConceptNet(nn.Module): | |||
self.fc = nn.Sequential( | |||
nn.Linear(num_features, intermidate_dim), | |||
nn.SiLU(), | |||
nn.Linear(intermidate_dim, num_concepts) | |||
nn.Linear(intermidate_dim, num_concepts), | |||
) | |||
def forward(self, x): | |||
@@ -1,6 +1,7 @@ | |||
# -*- coding: utf-8 -*- | |||
from ablkit.reasoning import KBBase | |||
class BDDKB(KBBase): | |||
def __init__(self, pseudo_label_list=None): | |||
if pseudo_label_list is None: | |||
@@ -20,9 +21,29 @@ class BDDKB(KBBase): | |||
(1, 0, 1, 1) 196 | |||
""" | |||
assert len(attrs) == 21 | |||
green_light, follow, road_clear, red_light, traffic_sign, car, person, rider, other_obstacle, \ | |||
left_lane, left_green_light, left_follow, no_left_lane, left_obstacle, left_solid_line, \ | |||
right_lane, right_green_light, right_follow, no_right_lane, right_obstacle, right_solid_line = attrs | |||
( | |||
green_light, | |||
follow, | |||
road_clear, | |||
red_light, | |||
traffic_sign, | |||
car, | |||
person, | |||
rider, | |||
other_obstacle, | |||
left_lane, | |||
left_green_light, | |||
left_follow, | |||
no_left_lane, | |||
left_obstacle, | |||
left_solid_line, | |||
right_lane, | |||
right_green_light, | |||
right_follow, | |||
no_right_lane, | |||
right_obstacle, | |||
right_solid_line, | |||
) = attrs | |||
illegal_return = (0, 0, 0, 0) | |||
if red_light == green_light == 1: | |||
@@ -30,8 +51,8 @@ class BDDKB(KBBase): | |||
obstacle = car or person or rider or other_obstacle | |||
if road_clear == obstacle: | |||
return illegal_return | |||
move_forward = (green_light or follow or road_clear) | |||
stop = (red_light or traffic_sign or obstacle) | |||
move_forward = green_light or follow or road_clear | |||
stop = red_light or traffic_sign or obstacle | |||
if stop: | |||
move_forward = 0 | |||
@@ -43,4 +64,4 @@ class BDDKB(KBBase): | |||
cannot_turn_right = no_right_lane or right_obstacle or right_solid_line | |||
turn_right = can_turn_right and int(not cannot_turn_right) | |||
return move_forward, stop, turn_left, turn_right | |||
return move_forward, stop, turn_left, turn_right |
@@ -1,2 +1 @@ | |||
torch | |||
ablkit |
@@ -81,7 +81,7 @@ def main(): | |||
"--label-smoothing", | |||
type=float, | |||
default=0.2, | |||
help="label smoothing in cross entropy loss (default : 0.2)" | |||
help="label smoothing in cross entropy loss (default : 0.2)", | |||
) | |||
parser.add_argument( | |||
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | |||
@@ -138,7 +138,12 @@ def main(): | |||
# Build BasicNN | |||
base_model = BasicNN( | |||
cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs, | |||
cls, | |||
loss_fn, | |||
optimizer, | |||
device=device, | |||
batch_size=args.batch_size, | |||
num_epochs=args.epochs, | |||
) | |||
# Build ABLModel | |||