Author | SHA1 | Message | Date |
---|---|---|---|
|
559e2ac91d | [DOC] update readme doc | 3 months ago |
|
d0df086e6e | [FIX] add emoji in title | 3 months ago |
|
51274d316d | [FIX] change font size | 3 months ago |
|
e86f12e1f1 | [FIX] fix placeholder | 3 months ago |
|
695f9d2176 | [ENH] add contributors | 3 months ago |
|
8db8c1b4a5 | [FIX] pass flake8 | 3 months ago |
|
79a26a0dc5
|
Merge pull request #12 from wnqn1597/examples
add BDD-OIA example |
3 months ago |
|
c9c82ed8e5 | update readme | 3 months ago |
|
c4c85dd02a | add BDD-OIA example | 3 months ago |
@@ -2,13 +2,13 @@ | |||||
<img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/logo.png" width="180"> | <img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/logo.png" width="180"> | ||||
[](https://pypi.org/project/ablkit/) [](https://pypi.org/project/ablkit/) [](https://ablkit.readthedocs.io/en/latest/?badge=latest) [](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [](https://github.com/psf/black) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) | |||||
[](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit) [](https://pypi.org/project/ablkit/) [](https://pypi.org/project/ablkit/) [](https://ablkit.readthedocs.io/en/latest/?badge=latest) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [](https://github.com/psf/black) [](https://pypi.org/project/ablkit/) | |||||
[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [📚Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new) | |||||
[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [🧪Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new) | |||||
</div> | </div> | ||||
# ABLkit: A Toolkit for Abductive Learning | |||||
# 🧰 ABLkit: A Toolkit for Abductive Learning 📊📐 | |||||
**ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available. | **ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available. | ||||
@@ -28,7 +28,7 @@ ABLkit encapsulates advanced ABL techniques, providing users with an efficient a | |||||
<img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/ABLkit.png" alt="ABLkit" style="width: 80%;"/> | <img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/ABLkit.png" alt="ABLkit" style="width: 80%;"/> | ||||
</p> | </p> | ||||
## Installation | |||||
## 🛠️ Installation | |||||
### Install from PyPI | ### Install from PyPI | ||||
@@ -60,7 +60,7 @@ sudo apt-get install swi-prolog | |||||
For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). | For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). | ||||
## Quick Start | |||||
## ⚡ Quick Start | |||||
We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. | We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. | ||||
@@ -184,7 +184,7 @@ bridge.test(test_data) | |||||
To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html). | To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html). | ||||
## Examples | |||||
## 🧪 Examples | |||||
We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file. | We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file. | ||||
@@ -192,8 +192,9 @@ We provide several examples in `examples/`. Each example is stored in a separate | |||||
+ [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf) | + [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf) | ||||
+ [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed) | + [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed) | ||||
+ [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo) | + [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo) | ||||
+ [BDD-OIA](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia) | |||||
## References | |||||
## 📚 References | |||||
For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). | For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). | ||||
@@ -220,7 +221,7 @@ For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichi | |||||
} | } | ||||
``` | ``` | ||||
## Citation | |||||
## 📝 Citation | |||||
To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7). | To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7). | ||||
@@ -234,4 +235,46 @@ To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://j | |||||
pages = {186354}, | pages = {186354}, | ||||
year = {2024} | year = {2024} | ||||
} | } | ||||
``` | |||||
``` | |||||
## ✨ Contributors | |||||
We would like to thank the following contributors for their efforts on this project: <sub><i>(*: current maintainer)</i></sub> | |||||
<table> | |||||
<tr> | |||||
<td align="center"> | |||||
<a href="https://github.com/Tony-HYX"> | |||||
<img src="https://avatars.githubusercontent.com/u/34394824?V=4" width="100px;" alt=""/> | |||||
<br /> | |||||
Yu-Xuan Huang | |||||
</a> | |||||
</td> | |||||
<td align="center"> | |||||
<a href="https://github.com/troyyyyy"> | |||||
<img src="https://avatars.githubusercontent.com/u/49091847?v=4" width="100px;" alt=""/> | |||||
<br /> | |||||
Wen-Chao Hu | |||||
</a>* | |||||
</td> | |||||
<td align="center"> | |||||
<a href="https://github.com/WaTerminator"> | |||||
<img src="https://avatars.githubusercontent.com/u/58843099?V=4" width="100px;" alt=""/> | |||||
<br /> | |||||
En-Hao Gao | |||||
</a> | |||||
</td> | |||||
<td align="center"> | |||||
<a href="https://github.com/snqn1597"> | |||||
<img src="https://avatars.githubusercontent.com/u/98020642?V=4" width="100px;" alt=""/> | |||||
<br /> | |||||
Qi-Jie Li | |||||
</a> | |||||
</td> | |||||
</tr> | |||||
</table> | |||||
We also thank the following users for their helpful suggestions and feedback: | |||||
- [Hao-Yuan He](https://github.com/Hao-Yuan-He) | |||||
- [Wang-Zhou Dai](https://github.com/haldai) |
@@ -180,8 +180,8 @@ class Reasoner: | |||||
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | ||||
return avg_confidence_dist(data_example.pred_prob, candidates_idxs) | return avg_confidence_dist(data_example.pred_prob, candidates_idxs) | ||||
else: | else: | ||||
candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||||
cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) | |||||
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||||
cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results) | |||||
if len(cost_list) != len(candidates): | if len(cost_list) != len(candidates): | ||||
raise ValueError( | raise ValueError( | ||||
"The length of the array returned by dist_func must be equal to the number " | "The length of the array returned by dist_func must be equal to the number " | ||||
@@ -0,0 +1,50 @@ | |||||
# BDD-OIA | |||||
This example shows an implementation of [BDD-OIA](https://twizwei.github.io/bddoia_project/) task. The BDD-OIA dataset comprises frames extracted from driving scene videos, which are utilized for autonomous driving predictions. Each frame is annotated with 4 binary labels, indicating the possible actions, namely $\textsf{move forward}$, $\textsf{stop}$, $\textsf{turn left}$, $\textsf{turn right}$. Each frame is also annotated with 21 intermediate binary concepts such as $\textsf{red light}$, $\textsf{road clear}$, etc., underlying the reasons for the possible actions. | |||||
The objective is to predict possible actions for each frame. During training, we only make use of the label supervision, along with a knowledge base, which comprises information about the relations between concepts and actions, e.g., $\textsf{red light} \lor \textsf{traffic sign} \lor \textsf{obstacle} \implies \textsf{stop}$. The training set consists of 16,000 frames, while the test set contains 4,500 annotated data points. | |||||
Before usage, the dataset was pre-processed by [Marconato et al. (2023)](https://proceedings.neurips.cc/paper_files/paper/2023/file/e560202b6e779a82478edb46c6f8f4dd-Paper-Conference.pdf) using a pretrained Faster-RCNN model on BDD-100k, in conjunction with the first module in CBM-AUC [(Sawada & Nakamura, 2022)](https://arxiv.org/abs/2202.01459), resulting in embeddings of dimension 2048. | |||||
## Run | |||||
```bash | |||||
pip install -r requirements.txt | |||||
cd dataset | |||||
unzip dataset.zip | |||||
cd .. | |||||
python main.py | |||||
``` | |||||
## Usage | |||||
```bash | |||||
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||||
[--batch-size BATCH_SIZE] [--loops LOOPS] | |||||
[--segment_size SEGMENT_SIZE] | |||||
[--save_interval SAVE_INTERVAL] | |||||
[--max-revision MAX_REVISION] | |||||
[--require-more-revision REQUIRE_MORE_REVISION] | |||||
BDD_OIA example | |||||
optional arguments: | |||||
-h, --help show this help message and exit | |||||
--no-cuda disables CUDA training | |||||
--epochs EPOCHS number of epochs in each learning loop iteration | |||||
(default : 1) | |||||
--lr LR base model learning rate (default : 0.002) | |||||
--batch-size BATCH_SIZE | |||||
base model batch size (default : 32) | |||||
--loops LOOPS | |||||
number of loop iterations (default : 2) | |||||
--segment_size SEGMENT_SIZE | |||||
segment size (default : 0.01) | |||||
--save_interval SAVE_INTERVAL | |||||
save interval (default : 1) | |||||
--max-revision MAX_REVISION | |||||
maximum revision in reasoner (default : 3) | |||||
-require-more-revision REQUIRE_MORE_REVISION | |||||
require more revision in reasoner (default : 3) | |||||
``` |
@@ -0,0 +1,24 @@ | |||||
import numpy as np | |||||
from typing import List, Any | |||||
from ablkit.data import ListData | |||||
from ablkit.bridge import SimpleBridge | |||||
class BDDBridge(SimpleBridge): | |||||
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||||
pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] | |||||
pred_pseudo_label = [] | |||||
for sub_list in pred_idx: | |||||
sub_list = sub_list.squeeze() # 1 x nc -> nc | |||||
pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list]) | |||||
data_examples.pred_pseudo_label = pred_pseudo_label | |||||
return data_examples.pred_pseudo_label | |||||
def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]: | |||||
abduced_pseudo_label = data_examples.abduced_pseudo_label | |||||
abduced_idx = [] | |||||
for sub_list in abduced_pseudo_label: | |||||
sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) | |||||
abduced_idx.append(sub_list) | |||||
data_examples.abduced_idx = abduced_idx | |||||
return data_examples.abduced_idx |
@@ -0,0 +1,19 @@ | |||||
import os | |||||
import numpy as np | |||||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||||
def get_dataset(fname, get_pseudo_label=True): | |||||
fname = os.path.join(CURRENT_DIR, fname) | |||||
data = np.load(fname) | |||||
X = data["X"] | |||||
X = [[emb.astype(np.float32)] for emb in X] | |||||
pseudo_label = data["pseudo_label"].astype(int).tolist() if get_pseudo_label else None | |||||
Y = data["Y"][:, :4].astype(int).tolist() | |||||
Y = [tuple(y) for y in Y] | |||||
return X, pseudo_label, Y | |||||
if __name__ == "__main__": | |||||
dataset = get_dataset("val.npz") |
@@ -0,0 +1,149 @@ | |||||
import argparse | |||||
import os.path as osp | |||||
import numpy as np | |||||
import torch | |||||
from torch import optim | |||||
from ablkit.data.evaluation import SymbolAccuracy | |||||
from ablkit.reasoning import Reasoner | |||||
from ablkit.utils import ABLLogger, print_log | |||||
from models.nn import ConceptNet | |||||
from models.bdd_nn import BDDNN | |||||
from models.bdd_model import BDDABLModel | |||||
from reasoning.bddkb import BDDKB | |||||
from dataset.data_util import get_dataset | |||||
from bridge import BDDBridge | |||||
from metric import BDDReasoningMetric | |||||
def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): | |||||
pred_prob = data_example.pred_prob.T # nc x 1 | |||||
pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 | |||||
cols = np.arange(len(candidates_idxs[0]))[None, :] | |||||
corr_prob = pred_prob[cols, candidates_idxs] | |||||
costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) | |||||
return costs | |||||
def get_args(): | |||||
parser = argparse.ArgumentParser(description="BDD-OIA example") | |||||
parser.add_argument( | |||||
"--no-cuda", action="store_true", default=False, help="disables CUDA training" | |||||
) | |||||
parser.add_argument( | |||||
"--epochs", | |||||
type=int, | |||||
default=1, | |||||
help="number of epochs in each learning loop iteration (default : 1)", | |||||
) | |||||
parser.add_argument( | |||||
"--lr", type=float, default=2e-3, help="base model learning rate (default : 0.002)" | |||||
) | |||||
parser.add_argument( | |||||
"--batch-size", type=int, default=32, help="base model batch size (default : 32)" | |||||
) | |||||
parser.add_argument( | |||||
"--loops", type=int, default=2, help="number of loop iterations (default : 2)" | |||||
) | |||||
parser.add_argument( | |||||
"--segment_size", type=int, default=0.01, help="segment size (default : 0.01)" | |||||
) | |||||
parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") | |||||
parser.add_argument( | |||||
"--max-revision", type=int, default=3, help="maximum revision in reasoner (default : 3)" | |||||
) | |||||
parser.add_argument( | |||||
"--require-more-revision", | |||||
type=int, | |||||
default=3, | |||||
help="require more revision in reasoner (default : 3)", | |||||
) | |||||
args = parser.parse_args() | |||||
return args | |||||
def main(): | |||||
args = get_args() | |||||
# Build logger | |||||
print_log("Abductive Learning on the BDD-OIA example.", logger="current") | |||||
# -- Working with Data ------------------------------ | |||||
print_log("Working with Data.", logger="current") | |||||
train_data = get_dataset(fname="train.npz", get_pseudo_label=True) | |||||
val_data = get_dataset(fname="val.npz", get_pseudo_label=True) | |||||
test_data = get_dataset(fname="test.npz", get_pseudo_label=True) | |||||
# -- Building the Learning Part --------------------- | |||||
print_log("Building the Learning Part.", logger="current") | |||||
# Build necessary components for BDDNN | |||||
cls = ConceptNet() | |||||
loss_fn = nn.BCEWithLogitsLoss() | |||||
optimizer = optim.Adam(cls.parameters(), lr=args.lr) | |||||
use_cuda = not args.no_cuda and torch.cuda.is_available() | |||||
device = torch.device("cuda" if use_cuda else "cpu") | |||||
scheduler = optim.lr_scheduler.OneCycleLR( | |||||
optimizer, | |||||
max_lr=args.lr, | |||||
pct_start=0.15, | |||||
epochs=args.loops, | |||||
steps_per_epoch=int(1 / args.segment_size) + 1, | |||||
) | |||||
# Build BDDNN | |||||
base_model = BDDNN( | |||||
cls, | |||||
loss_fn, | |||||
optimizer, | |||||
scheduler=scheduler, | |||||
device=device, | |||||
batch_size=args.batch_size, | |||||
num_epochs=args.epochs, | |||||
) | |||||
# Build ABLModel | |||||
model = BDDABLModel(base_model) | |||||
# -- Building the Reasoning Part -------------------- | |||||
print_log("Building the Reasoning Part.", logger="current") | |||||
# Build knowledge base | |||||
kb = BDDKB() | |||||
# Create reasoner | |||||
reasoner = Reasoner( | |||||
kb, | |||||
dist_func=multi_label_confidence_dist, | |||||
max_revision=args.max_revision, | |||||
require_more_revision=args.require_more_revision, | |||||
) | |||||
# -- Building Evaluation Metrics -------------------- | |||||
print_log("Building Evaluation Metrics.", logger="current") | |||||
metric_list = [SymbolAccuracy(prefix="bdd_oia"), BDDReasoningMetric(kb=kb, prefix="bdd_oia")] | |||||
# -- Bridging Learning and Reasoning ---------------- | |||||
print_log("Bridge Learning and Reasoning.", logger="current") | |||||
bridge = BDDBridge(model, reasoner, metric_list) | |||||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||||
log_dir = ABLLogger.get_current_instance().log_dir | |||||
weights_dir = osp.join(log_dir, "weights") | |||||
# Train and Test | |||||
bridge.train( | |||||
train_data=train_data, | |||||
val_data=val_data, | |||||
loops=args.loops, | |||||
segment_size=args.segment_size, | |||||
save_interval=args.save_interval, | |||||
save_dir=weights_dir, | |||||
) | |||||
bridge.test(test_data) | |||||
if __name__ == "__main__": | |||||
main() |
@@ -0,0 +1,27 @@ | |||||
from typing import Optional | |||||
from ablkit.reasoning import KBBase | |||||
from ablkit.data import BaseMetric, ListData | |||||
class BDDReasoningMetric(BaseMetric): | |||||
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | |||||
super().__init__(prefix) | |||||
self.kb = kb | |||||
def process(self, data_examples: ListData) -> None: | |||||
pred_pseudo_label_list = data_examples.pred_pseudo_label | |||||
y_list = data_examples.Y | |||||
x_list = data_examples.X | |||||
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): | |||||
pred_y = self.kb.logic_forward( | |||||
pred_pseudo_label, *(x,) if self.kb._num_args == 2 else () | |||||
) | |||||
for py, yy in zip(pred_y, y): | |||||
self.results.append(int(py == yy)) | |||||
def compute_metrics(self) -> dict: | |||||
results = self.results | |||||
metrics = dict() | |||||
metrics["reasoning_accuracy"] = sum(results) / len(results) | |||||
return metrics |
@@ -0,0 +1,25 @@ | |||||
from typing import Dict | |||||
import numpy as np | |||||
from ablkit.data import ListData | |||||
from ablkit.learning import ABLModel | |||||
from ablkit.utils import reform_list | |||||
class BDDABLModel(ABLModel): | |||||
def predict(self, data_examples: ListData) -> Dict: | |||||
model = self.base_model | |||||
data_X = data_examples.flatten("X") | |||||
if hasattr(model, "predict_proba"): | |||||
prob = model.predict_proba(X=data_X) | |||||
label = np.where(prob > 0.5, 1, 0).astype(int) | |||||
prob = reform_list(prob, data_examples.X) | |||||
else: | |||||
prob = None | |||||
label = model.predict(X=data_X) | |||||
label = reform_list(label, data_examples.X) | |||||
data_examples.pred_idx = label | |||||
data_examples.pred_prob = prob | |||||
return {"label": label, "prob": prob} |
@@ -0,0 +1,93 @@ | |||||
import logging | |||||
from typing import Any, Callable, List, Optional | |||||
import numpy | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset | |||||
from ablkit.utils.logger import print_log | |||||
class MultiLabelClassificationDataset(ClassificationDataset): | |||||
def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None): | |||||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||||
raise ValueError("X and Y should be of type list.") | |||||
self.X = X | |||||
self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss | |||||
self.transform = transform | |||||
class BDDNN(BasicNN): | |||||
def predict( | |||||
self, | |||||
data_loader: Optional[DataLoader] = None, | |||||
X: Optional[List[Any]] = None, | |||||
) -> numpy.ndarray: | |||||
if data_loader is not None and X is not None: | |||||
print_log( | |||||
"Predict the class of input data in data_loader instead of X.", | |||||
logger="current", | |||||
level=logging.WARNING, | |||||
) | |||||
if data_loader is None: | |||||
dataset = PredictionDataset(X, self.test_transform) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=self.batch_size, | |||||
num_workers=self.num_workers, | |||||
collate_fn=self.collate_fn, | |||||
pin_memory=torch.cuda.is_available(), | |||||
) | |||||
pred_probs = self._predict(data_loader).sigmoid() | |||||
pred = torch.where(pred_probs > 0.5, 1, 0).int() | |||||
return pred.cpu().numpy() | |||||
def predict_proba( | |||||
self, | |||||
data_loader: Optional[DataLoader] = None, | |||||
X: Optional[List[Any]] = None, | |||||
) -> numpy.ndarray: | |||||
if data_loader is not None and X is not None: | |||||
print_log( | |||||
"Predict the class probability of input data in data_loader instead of X.", | |||||
logger="current", | |||||
level=logging.WARNING, | |||||
) | |||||
if data_loader is None: | |||||
dataset = PredictionDataset(X, self.test_transform) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=self.batch_size, | |||||
num_workers=self.num_workers, | |||||
collate_fn=self.collate_fn, | |||||
pin_memory=torch.cuda.is_available(), | |||||
) | |||||
pred_probs = self._predict(data_loader).sigmoid() # B x NC | |||||
return pred_probs.cpu().numpy() | |||||
def _data_loader( | |||||
self, | |||||
X: Optional[List[Any]], | |||||
y: Optional[List[int]] = None, | |||||
shuffle: Optional[bool] = True, | |||||
) -> DataLoader: | |||||
if X is None: | |||||
raise ValueError("X should not be None.") | |||||
if y is None: | |||||
y = [0] * len(X) | |||||
if not len(y) == len(X): | |||||
raise ValueError("X and y should have equal length.") | |||||
dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=self.batch_size, | |||||
shuffle=shuffle, | |||||
num_workers=self.num_workers, | |||||
collate_fn=self.collate_fn, | |||||
pin_memory=torch.cuda.is_available(), | |||||
) | |||||
return data_loader |
@@ -0,0 +1,24 @@ | |||||
from torch import nn | |||||
class SimpleNet(nn.Module): | |||||
def __init__(self, num_features=2048, num_concepts=21): | |||||
super(SimpleNet, self).__init__() | |||||
self.fc = nn.Linear(num_features, num_concepts) | |||||
def forward(self, x): | |||||
return self.fc(x) | |||||
class ConceptNet(nn.Module): | |||||
def __init__(self, num_features=2048, num_concepts=21): | |||||
super(ConceptNet, self).__init__() | |||||
intermidate_dim = 256 | |||||
self.fc = nn.Sequential( | |||||
nn.Linear(num_features, intermidate_dim), | |||||
nn.SiLU(), | |||||
nn.Linear(intermidate_dim, num_concepts), | |||||
) | |||||
def forward(self, x): | |||||
return self.fc(x) |
@@ -0,0 +1,67 @@ | |||||
# -*- coding: utf-8 -*- | |||||
from ablkit.reasoning import KBBase | |||||
class BDDKB(KBBase): | |||||
def __init__(self, pseudo_label_list=None): | |||||
if pseudo_label_list is None: | |||||
pseudo_label_list = [0, 1] | |||||
super().__init__(pseudo_label_list) | |||||
def logic_forward(self, attrs): | |||||
""" | |||||
Abduction space | |||||
(0, 1, 0, 0) 610812 | |||||
(0, 1, 0, 1) 75012 | |||||
(0, 1, 1, 0) 75012 | |||||
(0, 1, 1, 1) 9212 | |||||
(1, 0, 0, 0) 12996 | |||||
(1, 0, 0, 1) 1596 | |||||
(1, 0, 1, 0) 1596 | |||||
(1, 0, 1, 1) 196 | |||||
""" | |||||
assert len(attrs) == 21 | |||||
( | |||||
green_light, | |||||
follow, | |||||
road_clear, | |||||
red_light, | |||||
traffic_sign, | |||||
car, | |||||
person, | |||||
rider, | |||||
other_obstacle, | |||||
left_lane, | |||||
left_green_light, | |||||
left_follow, | |||||
no_left_lane, | |||||
left_obstacle, | |||||
left_solid_line, | |||||
right_lane, | |||||
right_green_light, | |||||
right_follow, | |||||
no_right_lane, | |||||
right_obstacle, | |||||
right_solid_line, | |||||
) = attrs | |||||
illegal_return = (0, 0, 0, 0) | |||||
if red_light == green_light == 1: | |||||
return illegal_return | |||||
obstacle = car or person or rider or other_obstacle | |||||
if road_clear == obstacle: | |||||
return illegal_return | |||||
move_forward = green_light or follow or road_clear | |||||
stop = red_light or traffic_sign or obstacle | |||||
if stop: | |||||
move_forward = 0 | |||||
can_turn_left = left_lane or left_green_light or left_follow | |||||
cannot_turn_left = no_left_lane or left_obstacle or left_solid_line | |||||
turn_left = can_turn_left and int(not cannot_turn_left) | |||||
can_turn_right = right_lane or right_green_light or right_follow | |||||
cannot_turn_right = no_right_lane or right_obstacle or right_solid_line | |||||
turn_right = can_turn_right and int(not cannot_turn_right) | |||||
return move_forward, stop, turn_left, turn_right |
@@ -0,0 +1 @@ | |||||
ablkit |
@@ -81,7 +81,7 @@ def main(): | |||||
"--label-smoothing", | "--label-smoothing", | ||||
type=float, | type=float, | ||||
default=0.2, | default=0.2, | ||||
help="label smoothing in cross entropy loss (default : 0.2)" | |||||
help="label smoothing in cross entropy loss (default : 0.2)", | |||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | ||||
@@ -138,7 +138,12 @@ def main(): | |||||
# Build BasicNN | # Build BasicNN | ||||
base_model = BasicNN( | base_model = BasicNN( | ||||
cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs, | |||||
cls, | |||||
loss_fn, | |||||
optimizer, | |||||
device=device, | |||||
batch_size=args.batch_size, | |||||
num_epochs=args.epochs, | |||||
) | ) | ||||
# Build ABLModel | # Build ABLModel | ||||