Browse Source

[FIX] pass flake8

examples
troyyyyy 3 months ago
parent
commit
8db8c1b4a5
10 changed files with 59 additions and 25 deletions
  1. +2
    -1
      examples/bdd_oia/bridge.py
  2. +3
    -2
      examples/bdd_oia/dataset/data_util.py
  3. +6
    -4
      examples/bdd_oia/main.py
  4. +5
    -2
      examples/bdd_oia/metric.py
  5. +2
    -1
      examples/bdd_oia/models/bdd_model.py
  6. +4
    -5
      examples/bdd_oia/models/bdd_nn.py
  7. +3
    -1
      examples/bdd_oia/models/nn.py
  8. +27
    -6
      examples/bdd_oia/reasoning/bddkb.py
  9. +0
    -1
      examples/bdd_oia/requirements.txt
  10. +7
    -2
      examples/hwf/main.py

+ 2
- 1
examples/bdd_oia/bridge.py View File

@@ -3,6 +3,7 @@ from typing import List, Any
from ablkit.data import ListData
from ablkit.bridge import SimpleBridge


class BDDBridge(SimpleBridge):
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ]
@@ -20,4 +21,4 @@ class BDDBridge(SimpleBridge):
sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list])
abduced_idx.append(sub_list)
data_examples.abduced_idx = abduced_idx
return data_examples.abduced_idx
return data_examples.abduced_idx

+ 3
- 2
examples/bdd_oia/dataset/data_util.py View File

@@ -14,5 +14,6 @@ def get_dataset(fname, get_pseudo_label=True):
Y = [tuple(y) for y in Y]
return X, pseudo_label, Y

if __name__ == '__main__':
dataset = get_dataset("val.npz")

if __name__ == "__main__":
dataset = get_dataset("val.npz")

+ 6
- 4
examples/bdd_oia/main.py View File

@@ -8,7 +8,7 @@ from ablkit.data.evaluation import SymbolAccuracy
from ablkit.reasoning import Reasoner
from ablkit.utils import ABLLogger, print_log

from models.nn import *
from models.nn import ConceptNet
from models.bdd_nn import BDDNN
from models.bdd_model import BDDABLModel
from reasoning.bddkb import BDDKB
@@ -19,12 +19,13 @@ from metric import BDDReasoningMetric

def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results):
pred_prob = data_example.pred_prob.T # nc x 1
pred_prob = np.concatenate([1-pred_prob, pred_prob], axis=1) # nc x 2
pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2
cols = np.arange(len(candidates_idxs[0]))[None, :]
corr_prob = pred_prob[cols, candidates_idxs]
costs = - np.sum(np.log(corr_prob + 1e-6), axis=1)
costs = -np.sum(np.log(corr_prob + 1e-6), axis=1)
return costs


def get_args():
parser = argparse.ArgumentParser(description="BDD-OIA example")
parser.add_argument(
@@ -62,6 +63,7 @@ def get_args():
args = parser.parse_args()
return args


def main():
args = get_args()

@@ -116,7 +118,7 @@ def main():
kb,
dist_func=multi_label_confidence_dist,
max_revision=args.max_revision,
require_more_revision=args.require_more_revision
require_more_revision=args.require_more_revision,
)

# -- Building Evaluation Metrics --------------------


+ 5
- 2
examples/bdd_oia/metric.py View File

@@ -3,6 +3,7 @@ from typing import Optional
from ablkit.reasoning import KBBase
from ablkit.data import BaseMetric, ListData


class BDDReasoningMetric(BaseMetric):
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
super().__init__(prefix)
@@ -13,7 +14,9 @@ class BDDReasoningMetric(BaseMetric):
y_list = data_examples.Y
x_list = data_examples.X
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
pred_y = self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ())
pred_y = self.kb.logic_forward(
pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()
)
for py, yy in zip(pred_y, y):
self.results.append(int(py == yy))

@@ -21,4 +24,4 @@ class BDDReasoningMetric(BaseMetric):
results = self.results
metrics = dict()
metrics["reasoning_accuracy"] = sum(results) / len(results)
return metrics
return metrics

+ 2
- 1
examples/bdd_oia/models/bdd_model.py View File

@@ -5,6 +5,7 @@ from ablkit.data import ListData
from ablkit.learning import ABLModel
from ablkit.utils import reform_list


class BDDABLModel(ABLModel):
def predict(self, data_examples: ListData) -> Dict:
model = self.base_model
@@ -21,4 +22,4 @@ class BDDABLModel(ABLModel):
data_examples.pred_idx = label
data_examples.pred_prob = prob

return {"label": label, "prob": prob}
return {"label": label, "prob": prob}

+ 4
- 5
examples/bdd_oia/models/bdd_nn.py View File

@@ -1,10 +1,9 @@
import logging
import os
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional

import numpy
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader

from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset
from ablkit.utils.logger import print_log
@@ -15,11 +14,11 @@ class MultiLabelClassificationDataset(ClassificationDataset):
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
self.X = X
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss
self.transform = transform

class BDDNN(BasicNN):

class BDDNN(BasicNN):
def predict(
self,
data_loader: Optional[DataLoader] = None,


+ 3
- 1
examples/bdd_oia/models/nn.py View File

@@ -1,5 +1,6 @@
from torch import nn


class SimpleNet(nn.Module):
def __init__(self, num_features=2048, num_concepts=21):
super(SimpleNet, self).__init__()
@@ -8,6 +9,7 @@ class SimpleNet(nn.Module):
def forward(self, x):
return self.fc(x)


class ConceptNet(nn.Module):
def __init__(self, num_features=2048, num_concepts=21):
super(ConceptNet, self).__init__()
@@ -15,7 +17,7 @@ class ConceptNet(nn.Module):
self.fc = nn.Sequential(
nn.Linear(num_features, intermidate_dim),
nn.SiLU(),
nn.Linear(intermidate_dim, num_concepts)
nn.Linear(intermidate_dim, num_concepts),
)

def forward(self, x):


+ 27
- 6
examples/bdd_oia/reasoning/bddkb.py View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from ablkit.reasoning import KBBase


class BDDKB(KBBase):
def __init__(self, pseudo_label_list=None):
if pseudo_label_list is None:
@@ -20,9 +21,29 @@ class BDDKB(KBBase):
(1, 0, 1, 1) 196
"""
assert len(attrs) == 21
green_light, follow, road_clear, red_light, traffic_sign, car, person, rider, other_obstacle, \
left_lane, left_green_light, left_follow, no_left_lane, left_obstacle, left_solid_line, \
right_lane, right_green_light, right_follow, no_right_lane, right_obstacle, right_solid_line = attrs
(
green_light,
follow,
road_clear,
red_light,
traffic_sign,
car,
person,
rider,
other_obstacle,
left_lane,
left_green_light,
left_follow,
no_left_lane,
left_obstacle,
left_solid_line,
right_lane,
right_green_light,
right_follow,
no_right_lane,
right_obstacle,
right_solid_line,
) = attrs

illegal_return = (0, 0, 0, 0)
if red_light == green_light == 1:
@@ -30,8 +51,8 @@ class BDDKB(KBBase):
obstacle = car or person or rider or other_obstacle
if road_clear == obstacle:
return illegal_return
move_forward = (green_light or follow or road_clear)
stop = (red_light or traffic_sign or obstacle)
move_forward = green_light or follow or road_clear
stop = red_light or traffic_sign or obstacle
if stop:
move_forward = 0

@@ -43,4 +64,4 @@ class BDDKB(KBBase):
cannot_turn_right = no_right_lane or right_obstacle or right_solid_line
turn_right = can_turn_right and int(not cannot_turn_right)

return move_forward, stop, turn_left, turn_right
return move_forward, stop, turn_left, turn_right

+ 0
- 1
examples/bdd_oia/requirements.txt View File

@@ -1,2 +1 @@
torch
ablkit

+ 7
- 2
examples/hwf/main.py View File

@@ -81,7 +81,7 @@ def main():
"--label-smoothing",
type=float,
default=0.2,
help="label smoothing in cross entropy loss (default : 0.2)"
help="label smoothing in cross entropy loss (default : 0.2)",
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
@@ -138,7 +138,12 @@ def main():

# Build BasicNN
base_model = BasicNN(
cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs,
cls,
loss_fn,
optimizer,
device=device,
batch_size=args.batch_size,
num_epochs=args.epochs,
)

# Build ABLModel


Loading…
Cancel
Save