Browse Source

[FIX] pass flake8

examples
troyyyyy 3 months ago
parent
commit
8db8c1b4a5
10 changed files with 59 additions and 25 deletions
  1. +2
    -1
      examples/bdd_oia/bridge.py
  2. +3
    -2
      examples/bdd_oia/dataset/data_util.py
  3. +6
    -4
      examples/bdd_oia/main.py
  4. +5
    -2
      examples/bdd_oia/metric.py
  5. +2
    -1
      examples/bdd_oia/models/bdd_model.py
  6. +4
    -5
      examples/bdd_oia/models/bdd_nn.py
  7. +3
    -1
      examples/bdd_oia/models/nn.py
  8. +27
    -6
      examples/bdd_oia/reasoning/bddkb.py
  9. +0
    -1
      examples/bdd_oia/requirements.txt
  10. +7
    -2
      examples/hwf/main.py

+ 2
- 1
examples/bdd_oia/bridge.py View File

@@ -3,6 +3,7 @@ from typing import List, Any
from ablkit.data import ListData from ablkit.data import ListData
from ablkit.bridge import SimpleBridge from ablkit.bridge import SimpleBridge



class BDDBridge(SimpleBridge): class BDDBridge(SimpleBridge):
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ]
@@ -20,4 +21,4 @@ class BDDBridge(SimpleBridge):
sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list])
abduced_idx.append(sub_list) abduced_idx.append(sub_list)
data_examples.abduced_idx = abduced_idx data_examples.abduced_idx = abduced_idx
return data_examples.abduced_idx
return data_examples.abduced_idx

+ 3
- 2
examples/bdd_oia/dataset/data_util.py View File

@@ -14,5 +14,6 @@ def get_dataset(fname, get_pseudo_label=True):
Y = [tuple(y) for y in Y] Y = [tuple(y) for y in Y]
return X, pseudo_label, Y return X, pseudo_label, Y


if __name__ == '__main__':
dataset = get_dataset("val.npz")

if __name__ == "__main__":
dataset = get_dataset("val.npz")

+ 6
- 4
examples/bdd_oia/main.py View File

@@ -8,7 +8,7 @@ from ablkit.data.evaluation import SymbolAccuracy
from ablkit.reasoning import Reasoner from ablkit.reasoning import Reasoner
from ablkit.utils import ABLLogger, print_log from ablkit.utils import ABLLogger, print_log


from models.nn import *
from models.nn import ConceptNet
from models.bdd_nn import BDDNN from models.bdd_nn import BDDNN
from models.bdd_model import BDDABLModel from models.bdd_model import BDDABLModel
from reasoning.bddkb import BDDKB from reasoning.bddkb import BDDKB
@@ -19,12 +19,13 @@ from metric import BDDReasoningMetric


def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results):
pred_prob = data_example.pred_prob.T # nc x 1 pred_prob = data_example.pred_prob.T # nc x 1
pred_prob = np.concatenate([1-pred_prob, pred_prob], axis=1) # nc x 2
pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2
cols = np.arange(len(candidates_idxs[0]))[None, :] cols = np.arange(len(candidates_idxs[0]))[None, :]
corr_prob = pred_prob[cols, candidates_idxs] corr_prob = pred_prob[cols, candidates_idxs]
costs = - np.sum(np.log(corr_prob + 1e-6), axis=1)
costs = -np.sum(np.log(corr_prob + 1e-6), axis=1)
return costs return costs



def get_args(): def get_args():
parser = argparse.ArgumentParser(description="BDD-OIA example") parser = argparse.ArgumentParser(description="BDD-OIA example")
parser.add_argument( parser.add_argument(
@@ -62,6 +63,7 @@ def get_args():
args = parser.parse_args() args = parser.parse_args()
return args return args



def main(): def main():
args = get_args() args = get_args()


@@ -116,7 +118,7 @@ def main():
kb, kb,
dist_func=multi_label_confidence_dist, dist_func=multi_label_confidence_dist,
max_revision=args.max_revision, max_revision=args.max_revision,
require_more_revision=args.require_more_revision
require_more_revision=args.require_more_revision,
) )


# -- Building Evaluation Metrics -------------------- # -- Building Evaluation Metrics --------------------


+ 5
- 2
examples/bdd_oia/metric.py View File

@@ -3,6 +3,7 @@ from typing import Optional
from ablkit.reasoning import KBBase from ablkit.reasoning import KBBase
from ablkit.data import BaseMetric, ListData from ablkit.data import BaseMetric, ListData



class BDDReasoningMetric(BaseMetric): class BDDReasoningMetric(BaseMetric):
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
super().__init__(prefix) super().__init__(prefix)
@@ -13,7 +14,9 @@ class BDDReasoningMetric(BaseMetric):
y_list = data_examples.Y y_list = data_examples.Y
x_list = data_examples.X x_list = data_examples.X
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
pred_y = self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ())
pred_y = self.kb.logic_forward(
pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()
)
for py, yy in zip(pred_y, y): for py, yy in zip(pred_y, y):
self.results.append(int(py == yy)) self.results.append(int(py == yy))


@@ -21,4 +24,4 @@ class BDDReasoningMetric(BaseMetric):
results = self.results results = self.results
metrics = dict() metrics = dict()
metrics["reasoning_accuracy"] = sum(results) / len(results) metrics["reasoning_accuracy"] = sum(results) / len(results)
return metrics
return metrics

+ 2
- 1
examples/bdd_oia/models/bdd_model.py View File

@@ -5,6 +5,7 @@ from ablkit.data import ListData
from ablkit.learning import ABLModel from ablkit.learning import ABLModel
from ablkit.utils import reform_list from ablkit.utils import reform_list



class BDDABLModel(ABLModel): class BDDABLModel(ABLModel):
def predict(self, data_examples: ListData) -> Dict: def predict(self, data_examples: ListData) -> Dict:
model = self.base_model model = self.base_model
@@ -21,4 +22,4 @@ class BDDABLModel(ABLModel):
data_examples.pred_idx = label data_examples.pred_idx = label
data_examples.pred_prob = prob data_examples.pred_prob = prob


return {"label": label, "prob": prob}
return {"label": label, "prob": prob}

+ 4
- 5
examples/bdd_oia/models/bdd_nn.py View File

@@ -1,10 +1,9 @@
import logging import logging
import os
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional


import numpy import numpy
import torch import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader


from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset
from ablkit.utils.logger import print_log from ablkit.utils.logger import print_log
@@ -15,11 +14,11 @@ class MultiLabelClassificationDataset(ClassificationDataset):
if (not isinstance(X, list)) or (not isinstance(Y, list)): if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.") raise ValueError("X and Y should be of type list.")
self.X = X self.X = X
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss
self.transform = transform self.transform = transform


class BDDNN(BasicNN):


class BDDNN(BasicNN):
def predict( def predict(
self, self,
data_loader: Optional[DataLoader] = None, data_loader: Optional[DataLoader] = None,


+ 3
- 1
examples/bdd_oia/models/nn.py View File

@@ -1,5 +1,6 @@
from torch import nn from torch import nn



class SimpleNet(nn.Module): class SimpleNet(nn.Module):
def __init__(self, num_features=2048, num_concepts=21): def __init__(self, num_features=2048, num_concepts=21):
super(SimpleNet, self).__init__() super(SimpleNet, self).__init__()
@@ -8,6 +9,7 @@ class SimpleNet(nn.Module):
def forward(self, x): def forward(self, x):
return self.fc(x) return self.fc(x)



class ConceptNet(nn.Module): class ConceptNet(nn.Module):
def __init__(self, num_features=2048, num_concepts=21): def __init__(self, num_features=2048, num_concepts=21):
super(ConceptNet, self).__init__() super(ConceptNet, self).__init__()
@@ -15,7 +17,7 @@ class ConceptNet(nn.Module):
self.fc = nn.Sequential( self.fc = nn.Sequential(
nn.Linear(num_features, intermidate_dim), nn.Linear(num_features, intermidate_dim),
nn.SiLU(), nn.SiLU(),
nn.Linear(intermidate_dim, num_concepts)
nn.Linear(intermidate_dim, num_concepts),
) )


def forward(self, x): def forward(self, x):


+ 27
- 6
examples/bdd_oia/reasoning/bddkb.py View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from ablkit.reasoning import KBBase from ablkit.reasoning import KBBase



class BDDKB(KBBase): class BDDKB(KBBase):
def __init__(self, pseudo_label_list=None): def __init__(self, pseudo_label_list=None):
if pseudo_label_list is None: if pseudo_label_list is None:
@@ -20,9 +21,29 @@ class BDDKB(KBBase):
(1, 0, 1, 1) 196 (1, 0, 1, 1) 196
""" """
assert len(attrs) == 21 assert len(attrs) == 21
green_light, follow, road_clear, red_light, traffic_sign, car, person, rider, other_obstacle, \
left_lane, left_green_light, left_follow, no_left_lane, left_obstacle, left_solid_line, \
right_lane, right_green_light, right_follow, no_right_lane, right_obstacle, right_solid_line = attrs
(
green_light,
follow,
road_clear,
red_light,
traffic_sign,
car,
person,
rider,
other_obstacle,
left_lane,
left_green_light,
left_follow,
no_left_lane,
left_obstacle,
left_solid_line,
right_lane,
right_green_light,
right_follow,
no_right_lane,
right_obstacle,
right_solid_line,
) = attrs


illegal_return = (0, 0, 0, 0) illegal_return = (0, 0, 0, 0)
if red_light == green_light == 1: if red_light == green_light == 1:
@@ -30,8 +51,8 @@ class BDDKB(KBBase):
obstacle = car or person or rider or other_obstacle obstacle = car or person or rider or other_obstacle
if road_clear == obstacle: if road_clear == obstacle:
return illegal_return return illegal_return
move_forward = (green_light or follow or road_clear)
stop = (red_light or traffic_sign or obstacle)
move_forward = green_light or follow or road_clear
stop = red_light or traffic_sign or obstacle
if stop: if stop:
move_forward = 0 move_forward = 0


@@ -43,4 +64,4 @@ class BDDKB(KBBase):
cannot_turn_right = no_right_lane or right_obstacle or right_solid_line cannot_turn_right = no_right_lane or right_obstacle or right_solid_line
turn_right = can_turn_right and int(not cannot_turn_right) turn_right = can_turn_right and int(not cannot_turn_right)


return move_forward, stop, turn_left, turn_right
return move_forward, stop, turn_left, turn_right

+ 0
- 1
examples/bdd_oia/requirements.txt View File

@@ -1,2 +1 @@
torch
ablkit ablkit

+ 7
- 2
examples/hwf/main.py View File

@@ -81,7 +81,7 @@ def main():
"--label-smoothing", "--label-smoothing",
type=float, type=float,
default=0.2, default=0.2,
help="label smoothing in cross entropy loss (default : 0.2)"
help="label smoothing in cross entropy loss (default : 0.2)",
) )
parser.add_argument( parser.add_argument(
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
@@ -138,7 +138,12 @@ def main():


# Build BasicNN # Build BasicNN
base_model = BasicNN( base_model = BasicNN(
cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs,
cls,
loss_fn,
optimizer,
device=device,
batch_size=args.batch_size,
num_epochs=args.epochs,
) )


# Build ABLModel # Build ABLModel


Loading…
Cancel
Save