@@ -3,6 +3,7 @@ from typing import List, Any | |||||
from ablkit.data import ListData | from ablkit.data import ListData | ||||
from ablkit.bridge import SimpleBridge | from ablkit.bridge import SimpleBridge | ||||
class BDDBridge(SimpleBridge): | class BDDBridge(SimpleBridge): | ||||
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | ||||
pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] | pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] | ||||
@@ -20,4 +21,4 @@ class BDDBridge(SimpleBridge): | |||||
sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) | sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) | ||||
abduced_idx.append(sub_list) | abduced_idx.append(sub_list) | ||||
data_examples.abduced_idx = abduced_idx | data_examples.abduced_idx = abduced_idx | ||||
return data_examples.abduced_idx | |||||
return data_examples.abduced_idx |
@@ -14,5 +14,6 @@ def get_dataset(fname, get_pseudo_label=True): | |||||
Y = [tuple(y) for y in Y] | Y = [tuple(y) for y in Y] | ||||
return X, pseudo_label, Y | return X, pseudo_label, Y | ||||
if __name__ == '__main__': | |||||
dataset = get_dataset("val.npz") | |||||
if __name__ == "__main__": | |||||
dataset = get_dataset("val.npz") |
@@ -8,7 +8,7 @@ from ablkit.data.evaluation import SymbolAccuracy | |||||
from ablkit.reasoning import Reasoner | from ablkit.reasoning import Reasoner | ||||
from ablkit.utils import ABLLogger, print_log | from ablkit.utils import ABLLogger, print_log | ||||
from models.nn import * | |||||
from models.nn import ConceptNet | |||||
from models.bdd_nn import BDDNN | from models.bdd_nn import BDDNN | ||||
from models.bdd_model import BDDABLModel | from models.bdd_model import BDDABLModel | ||||
from reasoning.bddkb import BDDKB | from reasoning.bddkb import BDDKB | ||||
@@ -19,12 +19,13 @@ from metric import BDDReasoningMetric | |||||
def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): | def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): | ||||
pred_prob = data_example.pred_prob.T # nc x 1 | pred_prob = data_example.pred_prob.T # nc x 1 | ||||
pred_prob = np.concatenate([1-pred_prob, pred_prob], axis=1) # nc x 2 | |||||
pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 | |||||
cols = np.arange(len(candidates_idxs[0]))[None, :] | cols = np.arange(len(candidates_idxs[0]))[None, :] | ||||
corr_prob = pred_prob[cols, candidates_idxs] | corr_prob = pred_prob[cols, candidates_idxs] | ||||
costs = - np.sum(np.log(corr_prob + 1e-6), axis=1) | |||||
costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) | |||||
return costs | return costs | ||||
def get_args(): | def get_args(): | ||||
parser = argparse.ArgumentParser(description="BDD-OIA example") | parser = argparse.ArgumentParser(description="BDD-OIA example") | ||||
parser.add_argument( | parser.add_argument( | ||||
@@ -62,6 +63,7 @@ def get_args(): | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
return args | return args | ||||
def main(): | def main(): | ||||
args = get_args() | args = get_args() | ||||
@@ -116,7 +118,7 @@ def main(): | |||||
kb, | kb, | ||||
dist_func=multi_label_confidence_dist, | dist_func=multi_label_confidence_dist, | ||||
max_revision=args.max_revision, | max_revision=args.max_revision, | ||||
require_more_revision=args.require_more_revision | |||||
require_more_revision=args.require_more_revision, | |||||
) | ) | ||||
# -- Building Evaluation Metrics -------------------- | # -- Building Evaluation Metrics -------------------- | ||||
@@ -3,6 +3,7 @@ from typing import Optional | |||||
from ablkit.reasoning import KBBase | from ablkit.reasoning import KBBase | ||||
from ablkit.data import BaseMetric, ListData | from ablkit.data import BaseMetric, ListData | ||||
class BDDReasoningMetric(BaseMetric): | class BDDReasoningMetric(BaseMetric): | ||||
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | ||||
super().__init__(prefix) | super().__init__(prefix) | ||||
@@ -13,7 +14,9 @@ class BDDReasoningMetric(BaseMetric): | |||||
y_list = data_examples.Y | y_list = data_examples.Y | ||||
x_list = data_examples.X | x_list = data_examples.X | ||||
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): | for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): | ||||
pred_y = self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()) | |||||
pred_y = self.kb.logic_forward( | |||||
pred_pseudo_label, *(x,) if self.kb._num_args == 2 else () | |||||
) | |||||
for py, yy in zip(pred_y, y): | for py, yy in zip(pred_y, y): | ||||
self.results.append(int(py == yy)) | self.results.append(int(py == yy)) | ||||
@@ -21,4 +24,4 @@ class BDDReasoningMetric(BaseMetric): | |||||
results = self.results | results = self.results | ||||
metrics = dict() | metrics = dict() | ||||
metrics["reasoning_accuracy"] = sum(results) / len(results) | metrics["reasoning_accuracy"] = sum(results) / len(results) | ||||
return metrics | |||||
return metrics |
@@ -5,6 +5,7 @@ from ablkit.data import ListData | |||||
from ablkit.learning import ABLModel | from ablkit.learning import ABLModel | ||||
from ablkit.utils import reform_list | from ablkit.utils import reform_list | ||||
class BDDABLModel(ABLModel): | class BDDABLModel(ABLModel): | ||||
def predict(self, data_examples: ListData) -> Dict: | def predict(self, data_examples: ListData) -> Dict: | ||||
model = self.base_model | model = self.base_model | ||||
@@ -21,4 +22,4 @@ class BDDABLModel(ABLModel): | |||||
data_examples.pred_idx = label | data_examples.pred_idx = label | ||||
data_examples.pred_prob = prob | data_examples.pred_prob = prob | ||||
return {"label": label, "prob": prob} | |||||
return {"label": label, "prob": prob} |
@@ -1,10 +1,9 @@ | |||||
import logging | import logging | ||||
import os | |||||
from typing import Any, Callable, List, Optional, Tuple, Union | |||||
from typing import Any, Callable, List, Optional | |||||
import numpy | import numpy | ||||
import torch | import torch | ||||
from torch.utils.data import DataLoader, Dataset | |||||
from torch.utils.data import DataLoader | |||||
from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset | from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset | ||||
from ablkit.utils.logger import print_log | from ablkit.utils.logger import print_log | ||||
@@ -15,11 +14,11 @@ class MultiLabelClassificationDataset(ClassificationDataset): | |||||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | if (not isinstance(X, list)) or (not isinstance(Y, list)): | ||||
raise ValueError("X and Y should be of type list.") | raise ValueError("X and Y should be of type list.") | ||||
self.X = X | self.X = X | ||||
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss | |||||
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss | |||||
self.transform = transform | self.transform = transform | ||||
class BDDNN(BasicNN): | |||||
class BDDNN(BasicNN): | |||||
def predict( | def predict( | ||||
self, | self, | ||||
data_loader: Optional[DataLoader] = None, | data_loader: Optional[DataLoader] = None, | ||||
@@ -1,5 +1,6 @@ | |||||
from torch import nn | from torch import nn | ||||
class SimpleNet(nn.Module): | class SimpleNet(nn.Module): | ||||
def __init__(self, num_features=2048, num_concepts=21): | def __init__(self, num_features=2048, num_concepts=21): | ||||
super(SimpleNet, self).__init__() | super(SimpleNet, self).__init__() | ||||
@@ -8,6 +9,7 @@ class SimpleNet(nn.Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
return self.fc(x) | return self.fc(x) | ||||
class ConceptNet(nn.Module): | class ConceptNet(nn.Module): | ||||
def __init__(self, num_features=2048, num_concepts=21): | def __init__(self, num_features=2048, num_concepts=21): | ||||
super(ConceptNet, self).__init__() | super(ConceptNet, self).__init__() | ||||
@@ -15,7 +17,7 @@ class ConceptNet(nn.Module): | |||||
self.fc = nn.Sequential( | self.fc = nn.Sequential( | ||||
nn.Linear(num_features, intermidate_dim), | nn.Linear(num_features, intermidate_dim), | ||||
nn.SiLU(), | nn.SiLU(), | ||||
nn.Linear(intermidate_dim, num_concepts) | |||||
nn.Linear(intermidate_dim, num_concepts), | |||||
) | ) | ||||
def forward(self, x): | def forward(self, x): | ||||
@@ -1,6 +1,7 @@ | |||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
from ablkit.reasoning import KBBase | from ablkit.reasoning import KBBase | ||||
class BDDKB(KBBase): | class BDDKB(KBBase): | ||||
def __init__(self, pseudo_label_list=None): | def __init__(self, pseudo_label_list=None): | ||||
if pseudo_label_list is None: | if pseudo_label_list is None: | ||||
@@ -20,9 +21,29 @@ class BDDKB(KBBase): | |||||
(1, 0, 1, 1) 196 | (1, 0, 1, 1) 196 | ||||
""" | """ | ||||
assert len(attrs) == 21 | assert len(attrs) == 21 | ||||
green_light, follow, road_clear, red_light, traffic_sign, car, person, rider, other_obstacle, \ | |||||
left_lane, left_green_light, left_follow, no_left_lane, left_obstacle, left_solid_line, \ | |||||
right_lane, right_green_light, right_follow, no_right_lane, right_obstacle, right_solid_line = attrs | |||||
( | |||||
green_light, | |||||
follow, | |||||
road_clear, | |||||
red_light, | |||||
traffic_sign, | |||||
car, | |||||
person, | |||||
rider, | |||||
other_obstacle, | |||||
left_lane, | |||||
left_green_light, | |||||
left_follow, | |||||
no_left_lane, | |||||
left_obstacle, | |||||
left_solid_line, | |||||
right_lane, | |||||
right_green_light, | |||||
right_follow, | |||||
no_right_lane, | |||||
right_obstacle, | |||||
right_solid_line, | |||||
) = attrs | |||||
illegal_return = (0, 0, 0, 0) | illegal_return = (0, 0, 0, 0) | ||||
if red_light == green_light == 1: | if red_light == green_light == 1: | ||||
@@ -30,8 +51,8 @@ class BDDKB(KBBase): | |||||
obstacle = car or person or rider or other_obstacle | obstacle = car or person or rider or other_obstacle | ||||
if road_clear == obstacle: | if road_clear == obstacle: | ||||
return illegal_return | return illegal_return | ||||
move_forward = (green_light or follow or road_clear) | |||||
stop = (red_light or traffic_sign or obstacle) | |||||
move_forward = green_light or follow or road_clear | |||||
stop = red_light or traffic_sign or obstacle | |||||
if stop: | if stop: | ||||
move_forward = 0 | move_forward = 0 | ||||
@@ -43,4 +64,4 @@ class BDDKB(KBBase): | |||||
cannot_turn_right = no_right_lane or right_obstacle or right_solid_line | cannot_turn_right = no_right_lane or right_obstacle or right_solid_line | ||||
turn_right = can_turn_right and int(not cannot_turn_right) | turn_right = can_turn_right and int(not cannot_turn_right) | ||||
return move_forward, stop, turn_left, turn_right | |||||
return move_forward, stop, turn_left, turn_right |
@@ -1,2 +1 @@ | |||||
torch | |||||
ablkit | ablkit |
@@ -81,7 +81,7 @@ def main(): | |||||
"--label-smoothing", | "--label-smoothing", | ||||
type=float, | type=float, | ||||
default=0.2, | default=0.2, | ||||
help="label smoothing in cross entropy loss (default : 0.2)" | |||||
help="label smoothing in cross entropy loss (default : 0.2)", | |||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | ||||
@@ -138,7 +138,12 @@ def main(): | |||||
# Build BasicNN | # Build BasicNN | ||||
base_model = BasicNN( | base_model = BasicNN( | ||||
cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs, | |||||
cls, | |||||
loss_fn, | |||||
optimizer, | |||||
device=device, | |||||
batch_size=args.batch_size, | |||||
num_epochs=args.epochs, | |||||
) | ) | ||||
# Build ABLModel | # Build ABLModel | ||||