Compare commits

...

9 Commits

Author SHA1 Message Date
  troyyyyy 559e2ac91d [DOC] update readme doc 3 months ago
  troyyyyy d0df086e6e [FIX] add emoji in title 3 months ago
  troyyyyy 51274d316d [FIX] change font size 3 months ago
  troyyyyy e86f12e1f1 [FIX] fix placeholder 3 months ago
  troyyyyy 695f9d2176 [ENH] add contributors 3 months ago
  troyyyyy 8db8c1b4a5 [FIX] pass flake8 3 months ago
  AbductiveLearning 79a26a0dc5 Merge pull request #12 from wnqn1597/examples 3 months ago
  Beq Jal c9c82ed8e5 update readme 3 months ago
  Beq Jal c4c85dd02a add BDD-OIA example 3 months ago
14 changed files with 540 additions and 13 deletions
  1. README.md (+52, -9)
  2. ablkit/reasoning/reasoner.py (+2, -2)
  3. examples/bdd_oia/README.md (+50, -0)
  4. examples/bdd_oia/bridge.py (+24, -0)
  5. examples/bdd_oia/dataset/data_util.py (+19, -0)
  6. examples/bdd_oia/dataset/dataset.zip (binary)
  7. examples/bdd_oia/main.py (+149, -0)
  8. examples/bdd_oia/metric.py (+27, -0)
  9. examples/bdd_oia/models/bdd_model.py (+25, -0)
  10. examples/bdd_oia/models/bdd_nn.py (+93, -0)
  11. examples/bdd_oia/models/nn.py (+24, -0)
  12. examples/bdd_oia/reasoning/bddkb.py (+67, -0)
  13. examples/bdd_oia/requirements.txt (+1, -0)
  14. examples/hwf/main.py (+7, -2)

README.md (+52, -9)

@@ -2,13 +2,13 @@

<img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/logo.png" width="180">

[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ablkit)](https://pypi.org/project/ablkit/) [![PyPI version](https://badgen.net/pypi/v/ablkit)](https://pypi.org/project/ablkit/) [![Documentation Status](https://readthedocs.org/projects/ablkit/badge/?version=latest)](https://ablkit.readthedocs.io/en/latest/?badge=latest) [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [![flake8 Lint](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![ABLkit-CI](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml)
[![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [![last commit](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit)](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ablkit)](https://pypi.org/project/ablkit/) [![PyPI version](https://badgen.net/pypi/v/ablkit)](https://pypi.org/project/ablkit/) [![Documentation Status](https://readthedocs.org/projects/ablkit/badge/?version=latest)](https://ablkit.readthedocs.io/en/latest/?badge=latest) [![ABLkit-CI](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) [![flake8 Lint](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![PyPI - Downloads](https://img.shields.io/pypi/dm/ablkit)](https://pypi.org/project/ablkit/)

[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [📚Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new)
[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [🧪Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new)

</div>

# ABLkit: A Toolkit for Abductive Learning
# 🧰 ABLkit: A Toolkit for Abductive Learning 📊📐

**ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available.

@@ -28,7 +28,7 @@ ABLkit encapsulates advanced ABL techniques, providing users with an efficient a
<img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/ABLkit.png" alt="ABLkit" style="width: 80%;"/>
</p>

## Installation
## 🛠️ Installation

### Install from PyPI

@@ -60,7 +60,7 @@ sudo apt-get install swi-prolog

For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md).

## Quick Start
## Quick Start

We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum.
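
In ABLkit, the domain knowledge base for such a task is written by subclassing `KBBase` and implementing `logic_forward`, the same pattern followed by the new `examples/bdd_oia/reasoning/bddkb.py` in this pull request. Below is a minimal sketch for the addition task (the class name `AddKB` is illustrative; the complete quick-start code is in the documentation linked below):

```python
from ablkit.reasoning import KBBase


class AddKB(KBBase):
    def __init__(self, pseudo_label_list=None):
        # Pseudo-labels are the ten possible digits.
        super().__init__(pseudo_label_list or list(range(10)))

    def logic_forward(self, nums):
        # Reasoning over a pair of abduced digits is simply their sum.
        return sum(nums)
```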

@@ -184,7 +184,7 @@ bridge.test(test_data)

To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html).

## Examples
## 🧪 Examples

We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file.

@@ -192,8 +192,9 @@ We provide several examples in `examples/`. Each example is stored in a separate
+ [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf)
+ [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed)
+ [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo)
+ [BDD-OIA](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia)

## References
## 📚 References

For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf).

@@ -220,7 +221,7 @@ For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichi
}
```

## Citation
## 📝 Citation

To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7).

@@ -234,4 +235,46 @@ To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://j
pages = {186354},
year = {2024}
}
```
```

## ✨ Contributors

We would like to thank the following contributors for their efforts on this project: <sub><i>(*: current maintainer)</i></sub>

<table>
<tr>
<td align="center">
<a href="https://github.com/Tony-HYX">
<img src="https://avatars.githubusercontent.com/u/34394824?V=4" width="100px;" alt=""/>
<br />
Yu-Xuan Huang
</a>
</td>
<td align="center">
<a href="https://github.com/troyyyyy">
<img src="https://avatars.githubusercontent.com/u/49091847?v=4" width="100px;" alt=""/>
<br />
Wen-Chao Hu
</a>*
</td>
<td align="center">
<a href="https://github.com/WaTerminator">
<img src="https://avatars.githubusercontent.com/u/58843099?V=4" width="100px;" alt=""/>
<br />
En-Hao Gao
</a>
</td>
<td align="center">
<a href="https://github.com/snqn1597">
<img src="https://avatars.githubusercontent.com/u/98020642?V=4" width="100px;" alt=""/>
<br />
Qi-Jie Li
</a>
</td>
</tr>
</table>

We also thank the following users for their helpful suggestions and feedback:

- [Hao-Yuan He](https://github.com/Hao-Yuan-He)
- [Wang-Zhou Dai](https://github.com/haldai)

ablkit/reasoning/reasoner.py (+2, -2)

@@ -180,8 +180,8 @@ class Reasoner:
            candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
            return avg_confidence_dist(data_example.pred_prob, candidates_idxs)
        else:
-            candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
-            cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
+            candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
+            cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results)
            if len(cost_list) != len(candidates):
                raise ValueError(
                    "The length of the array returned by dist_func must be equal to the number "


examples/bdd_oia/README.md (+50, -0)

@@ -0,0 +1,50 @@
# BDD-OIA

This example shows an implementation of the [BDD-OIA](https://twizwei.github.io/bddoia_project/) task. The BDD-OIA dataset comprises frames extracted from driving-scene videos, which are used for autonomous-driving action prediction. Each frame is annotated with 4 binary labels indicating the possible actions, namely $\textsf{move forward}$, $\textsf{stop}$, $\textsf{turn left}$, and $\textsf{turn right}$. Each frame is also annotated with 21 intermediate binary concepts, such as $\textsf{red light}$ and $\textsf{road clear}$, which capture the reasons underlying the possible actions.

The objective is to predict the possible actions for each frame. During training, we only make use of label supervision, along with a knowledge base comprising the relations between concepts and actions, e.g., $\textsf{red light} \lor \textsf{traffic sign} \lor \textsf{obstacle} \implies \textsf{stop}$. The training set consists of 16,000 frames, while the test set contains 4,500 annotated data points.
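
For a concrete sense of how these relations are written down, the rule above appears in this example's knowledge base (`reasoning/bddkb.py`, included in full in this pull request) essentially as plain Python over the binary concepts:

```python
# Excerpt from BDDKB.logic_forward: a red light, a traffic sign, or any obstacle
# implies stop, and a stop overrides move_forward.
obstacle = car or person or rider or other_obstacle
stop = red_light or traffic_sign or obstacle
if stop:
    move_forward = 0
```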

Before use, the dataset was pre-processed by [Marconato et al. (2023)](https://proceedings.neurips.cc/paper_files/paper/2023/file/e560202b6e779a82478edb46c6f8f4dd-Paper-Conference.pdf) using a Faster-RCNN model pretrained on BDD-100k, in conjunction with the first module of CBM-AUC [(Sawada & Nakamura, 2022)](https://arxiv.org/abs/2202.01459), yielding embeddings of dimension 2048.
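
The pre-processed embeddings are packaged as `.npz` files and loaded by `dataset/data_util.py` (included in this pull request); roughly, each example becomes a single 2048-dimensional float32 vector, and only the first four columns of `Y` (the action labels) are kept:

```python
from dataset.data_util import get_dataset

# X: list of [2048-d float32 embedding]; pseudo_label: 21 binary concepts; Y: 4 action labels
X, pseudo_label, Y = get_dataset("train.npz", get_pseudo_label=True)
```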

## Run

```bash
pip install -r requirements.txt
cd dataset
unzip dataset.zip
cd ..
python main.py
```

## Usage

```bash
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
               [--batch-size BATCH_SIZE] [--loops LOOPS]
               [--segment_size SEGMENT_SIZE]
               [--save_interval SAVE_INTERVAL]
               [--max-revision MAX_REVISION]
               [--require-more-revision REQUIRE_MORE_REVISION]

BDD-OIA example

optional arguments:
  -h, --help            show this help message and exit
  --no-cuda             disables CUDA training
  --epochs EPOCHS       number of epochs in each learning loop iteration
                        (default : 1)
  --lr LR               base model learning rate (default : 0.002)
  --batch-size BATCH_SIZE
                        base model batch size (default : 32)
  --loops LOOPS         number of loop iterations (default : 2)
  --segment_size SEGMENT_SIZE
                        segment size (default : 0.01)
  --save_interval SAVE_INTERVAL
                        save interval (default : 1)
  --max-revision MAX_REVISION
                        maximum revision in reasoner (default : 3)
  --require-more-revision REQUIRE_MORE_REVISION
                        require more revision in reasoner (default : 3)

```
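
For instance, `python main.py --loops 3 --batch-size 64` runs three Abductive Learning loop iterations with a larger batch size, leaving the remaining options at their defaults.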

examples/bdd_oia/bridge.py (+24, -0)

@@ -0,0 +1,24 @@
import numpy as np
from typing import List, Any
from ablkit.data import ListData
from ablkit.bridge import SimpleBridge


class BDDBridge(SimpleBridge):
    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        pred_idx = data_examples.pred_idx  # [ ndarray(1, nc), ... ]
        pred_pseudo_label = []
        for sub_list in pred_idx:
            sub_list = sub_list.squeeze()  # 1 x nc -> nc
            pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list])
        data_examples.pred_pseudo_label = pred_pseudo_label
        return data_examples.pred_pseudo_label

    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        abduced_pseudo_label = data_examples.abduced_pseudo_label
        abduced_idx = []
        for sub_list in abduced_pseudo_label:
            sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list])
            abduced_idx.append(sub_list)
        data_examples.abduced_idx = abduced_idx
        return data_examples.abduced_idx

examples/bdd_oia/dataset/data_util.py (+19, -0)

@@ -0,0 +1,19 @@
import os
import numpy as np

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


def get_dataset(fname, get_pseudo_label=True):
    fname = os.path.join(CURRENT_DIR, fname)
    data = np.load(fname)
    X = data["X"]
    X = [[emb.astype(np.float32)] for emb in X]
    pseudo_label = data["pseudo_label"].astype(int).tolist() if get_pseudo_label else None
    Y = data["Y"][:, :4].astype(int).tolist()
    Y = [tuple(y) for y in Y]
    return X, pseudo_label, Y


if __name__ == "__main__":
    dataset = get_dataset("val.npz")

examples/bdd_oia/dataset/dataset.zip (binary)


examples/bdd_oia/main.py (+149, -0)

@@ -0,0 +1,149 @@
import argparse
import os.path as osp
import numpy as np
import torch
from torch import nn, optim

from ablkit.data.evaluation import SymbolAccuracy
from ablkit.reasoning import Reasoner
from ablkit.utils import ABLLogger, print_log

from models.nn import ConceptNet
from models.bdd_nn import BDDNN
from models.bdd_model import BDDABLModel
from reasoning.bddkb import BDDKB
from dataset.data_util import get_dataset
from bridge import BDDBridge
from metric import BDDReasoningMetric


def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results):
    # Cost of each candidate concept assignment: negative log-likelihood of the
    # candidate's 0/1 values under the model's per-concept probabilities.
    pred_prob = data_example.pred_prob.T  # nc x 1
    pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1)  # nc x 2
    cols = np.arange(len(candidates_idxs[0]))[None, :]
    corr_prob = pred_prob[cols, candidates_idxs]
    costs = -np.sum(np.log(corr_prob + 1e-6), axis=1)
    return costs


def get_args():
    parser = argparse.ArgumentParser(description="BDD-OIA example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs in each learning loop iteration (default : 1)",
    )
    parser.add_argument(
        "--lr", type=float, default=2e-3, help="base model learning rate (default : 0.002)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
    )
    parser.add_argument(
        "--loops", type=int, default=2, help="number of loop iterations (default : 2)"
    )
    parser.add_argument(
        "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision", type=int, default=3, help="maximum revision in reasoner (default : 3)"
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=3,
        help="require more revision in reasoner (default : 3)",
    )

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # Build logger
    print_log("Abductive Learning on the BDD-OIA example.", logger="current")

    # -- Working with Data ------------------------------
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(fname="train.npz", get_pseudo_label=True)
    val_data = get_dataset(fname="val.npz", get_pseudo_label=True)
    test_data = get_dataset(fname="test.npz", get_pseudo_label=True)

    # -- Building the Learning Part ---------------------
    print_log("Building the Learning Part.", logger="current")

    # Build necessary components for BDDNN
    cls = ConceptNet()
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(cls.parameters(), lr=args.lr)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        pct_start=0.15,
        epochs=args.loops,
        steps_per_epoch=int(1 / args.segment_size) + 1,
    )

    # Build BDDNN
    base_model = BDDNN(
        cls,
        loss_fn,
        optimizer,
        scheduler=scheduler,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = BDDABLModel(base_model)

    # -- Building the Reasoning Part --------------------
    print_log("Building the Reasoning Part.", logger="current")

    # Build knowledge base
    kb = BDDKB()

    # Create reasoner
    reasoner = Reasoner(
        kb,
        dist_func=multi_label_confidence_dist,
        max_revision=args.max_revision,
        require_more_revision=args.require_more_revision,
    )

    # -- Building Evaluation Metrics --------------------
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="bdd_oia"), BDDReasoningMetric(kb=kb, prefix="bdd_oia")]

    # -- Bridging Learning and Reasoning ----------------
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = BDDBridge(model, reasoner, metric_list)

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data=train_data,
        val_data=val_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()

examples/bdd_oia/metric.py (+27, -0)

@@ -0,0 +1,27 @@
from typing import Optional

from ablkit.reasoning import KBBase
from ablkit.data import BaseMetric, ListData


class BDDReasoningMetric(BaseMetric):
    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)
        self.kb = kb

    def process(self, data_examples: ListData) -> None:
        pred_pseudo_label_list = data_examples.pred_pseudo_label
        y_list = data_examples.Y
        x_list = data_examples.X
        for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
            pred_y = self.kb.logic_forward(
                pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()
            )
            for py, yy in zip(pred_y, y):
                self.results.append(int(py == yy))

    def compute_metrics(self) -> dict:
        results = self.results
        metrics = dict()
        metrics["reasoning_accuracy"] = sum(results) / len(results)
        return metrics

examples/bdd_oia/models/bdd_model.py (+25, -0)

@@ -0,0 +1,25 @@
from typing import Dict

import numpy as np
from ablkit.data import ListData
from ablkit.learning import ABLModel
from ablkit.utils import reform_list


class BDDABLModel(ABLModel):
    def predict(self, data_examples: ListData) -> Dict:
        model = self.base_model
        data_X = data_examples.flatten("X")
        if hasattr(model, "predict_proba"):
            prob = model.predict_proba(X=data_X)
            label = np.where(prob > 0.5, 1, 0).astype(int)
            prob = reform_list(prob, data_examples.X)
        else:
            prob = None
            label = model.predict(X=data_X)
        label = reform_list(label, data_examples.X)

        data_examples.pred_idx = label
        data_examples.pred_prob = prob

        return {"label": label, "prob": prob}

examples/bdd_oia/models/bdd_nn.py (+93, -0)

@@ -0,0 +1,93 @@
import logging
from typing import Any, Callable, List, Optional

import numpy
import torch
from torch.utils.data import DataLoader

from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset
from ablkit.utils.logger import print_log


class MultiLabelClassificationDataset(ClassificationDataset):
    def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None):
        if (not isinstance(X, list)) or (not isinstance(Y, list)):
            raise ValueError("X and Y should be of type list.")
        self.X = X
        self.Y = torch.FloatTensor(numpy.stack(Y, axis=0))  # float32 for BCELoss
        self.transform = transform


class BDDNN(BasicNN):
    def predict(
        self,
        data_loader: Optional[DataLoader] = None,
        X: Optional[List[Any]] = None,
    ) -> numpy.ndarray:
        if data_loader is not None and X is not None:
            print_log(
                "Predict the class of input data in data_loader instead of X.",
                logger="current",
                level=logging.WARNING,
            )

        if data_loader is None:
            dataset = PredictionDataset(X, self.test_transform)
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                pin_memory=torch.cuda.is_available(),
            )
        pred_probs = self._predict(data_loader).sigmoid()
        pred = torch.where(pred_probs > 0.5, 1, 0).int()
        return pred.cpu().numpy()

    def predict_proba(
        self,
        data_loader: Optional[DataLoader] = None,
        X: Optional[List[Any]] = None,
    ) -> numpy.ndarray:
        if data_loader is not None and X is not None:
            print_log(
                "Predict the class probability of input data in data_loader instead of X.",
                logger="current",
                level=logging.WARNING,
            )

        if data_loader is None:
            dataset = PredictionDataset(X, self.test_transform)
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                pin_memory=torch.cuda.is_available(),
            )
        pred_probs = self._predict(data_loader).sigmoid()  # B x NC
        return pred_probs.cpu().numpy()

    def _data_loader(
        self,
        X: Optional[List[Any]],
        y: Optional[List[int]] = None,
        shuffle: Optional[bool] = True,
    ) -> DataLoader:
        if X is None:
            raise ValueError("X should not be None.")
        if y is None:
            y = [0] * len(X)
        if not len(y) == len(X):
            raise ValueError("X and y should have equal length.")

        dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform)
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=torch.cuda.is_available(),
        )
        return data_loader

examples/bdd_oia/models/nn.py (+24, -0)

@@ -0,0 +1,24 @@
from torch import nn


class SimpleNet(nn.Module):
    def __init__(self, num_features=2048, num_concepts=21):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(num_features, num_concepts)

    def forward(self, x):
        return self.fc(x)


class ConceptNet(nn.Module):
    def __init__(self, num_features=2048, num_concepts=21):
        super(ConceptNet, self).__init__()
        intermediate_dim = 256
        self.fc = nn.Sequential(
            nn.Linear(num_features, intermediate_dim),
            nn.SiLU(),
            nn.Linear(intermediate_dim, num_concepts),
        )

    def forward(self, x):
        return self.fc(x)

examples/bdd_oia/reasoning/bddkb.py (+67, -0)

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
from ablkit.reasoning import KBBase


class BDDKB(KBBase):
    def __init__(self, pseudo_label_list=None):
        if pseudo_label_list is None:
            pseudo_label_list = [0, 1]
        super().__init__(pseudo_label_list)

    def logic_forward(self, attrs):
        """
        Abduction space
        (0, 1, 0, 0) 610812
        (0, 1, 0, 1) 75012
        (0, 1, 1, 0) 75012
        (0, 1, 1, 1) 9212
        (1, 0, 0, 0) 12996
        (1, 0, 0, 1) 1596
        (1, 0, 1, 0) 1596
        (1, 0, 1, 1) 196
        """
        assert len(attrs) == 21
        # Unpack the 21 binary concept labels.
        (
            green_light,
            follow,
            road_clear,
            red_light,
            traffic_sign,
            car,
            person,
            rider,
            other_obstacle,
            left_lane,
            left_green_light,
            left_follow,
            no_left_lane,
            left_obstacle,
            left_solid_line,
            right_lane,
            right_green_light,
            right_follow,
            no_right_lane,
            right_obstacle,
            right_solid_line,
        ) = attrs

        # Conflicting concept assignments are mapped to an all-zero action tuple.
        illegal_return = (0, 0, 0, 0)
        if red_light == green_light == 1:
            return illegal_return
        obstacle = car or person or rider or other_obstacle
        if road_clear == obstacle:
            return illegal_return
        move_forward = green_light or follow or road_clear
        stop = red_light or traffic_sign or obstacle
        if stop:
            move_forward = 0

        can_turn_left = left_lane or left_green_light or left_follow
        cannot_turn_left = no_left_lane or left_obstacle or left_solid_line
        turn_left = can_turn_left and int(not cannot_turn_left)

        can_turn_right = right_lane or right_green_light or right_follow
        cannot_turn_right = no_right_lane or right_obstacle or right_solid_line
        turn_right = can_turn_right and int(not cannot_turn_right)

        return move_forward, stop, turn_left, turn_right

examples/bdd_oia/requirements.txt (+1, -0)

@@ -0,0 +1 @@
ablkit

examples/hwf/main.py (+7, -2)

@@ -81,7 +81,7 @@ def main():
        "--label-smoothing",
        type=float,
        default=0.2,
-        help="label smoothing in cross entropy loss (default : 0.2)"
+        help="label smoothing in cross entropy loss (default : 0.2)",
    )
    parser.add_argument(
        "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
@@ -138,7 +138,12 @@ def main():

    # Build BasicNN
    base_model = BasicNN(
-        cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs,
+        cls,
+        loss_fn,
+        optimizer,
+        device=device,
+        batch_size=args.batch_size,
+        num_epochs=args.epochs,
    )

    # Build ABLModel

