@@ -28,13 +28,13 @@ model.
 import torch
 import torch.nn as nn
 import matplotlib.pyplot as plt
-from examples.hed.datasets import get_dataset, split_equation
-from examples.models.nn import SymbolNet
+from datasets import get_dataset, split_equation
+from models.nn import SymbolNet
 from abl.learning import ABLModel, BasicNN
-from examples.hed.reasoning import HedKB, HedReasoner
+from reasoning import HedKB, HedReasoner
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.utils import ABLLogger, print_log
-from examples.hed.bridge import HedBridge
+from bridge import HedBridge
 Working with Data
 -----------------
@@ -26,8 +26,8 @@ machine learning model.
 import torch
 import torch.nn as nn
 import matplotlib.pyplot as plt
-from examples.hwf.datasets import get_dataset
-from examples.models.nn import SymbolNet
+from datasets import get_dataset
+from models.nn import SymbolNet
 from abl.learning import ABLModel, BasicNN
 from abl.reasoning import KBBase, Reasoner
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
@@ -27,8 +27,8 @@ machine learning model.
 from torch.optim import RMSprop, lr_scheduler
-from examples.mnist_add.datasets import get_dataset
-from examples.models.nn import LeNet5
+from datasets import get_dataset
+from models.nn import LeNet5
 from abl.learning import ABLModel, BasicNN
 from abl.reasoning import KBBase, Reasoner
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
@@ -22,9 +22,9 @@ further update the learning model.
 import os.path as osp
 import numpy as np
 from sklearn.ensemble import RandomForestClassifier
-from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset
+from get_dataset import load_and_preprocess_dataset, split_dataset
 from abl.learning import ABLModel
-from examples.zoo.kb import ZooKB
+from kb import ZooKB
 from abl.reasoning import Reasoner
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.utils import ABLLogger, print_log, confidence_dist
@@ -21,7 +21,7 @@ In the MNIST Addition task, the data loading looks like
 .. code:: python
-from examples.mnist_add.datasets.get_mnist_add import get_mnist_add
+from datasets.get_mnist_add import get_mnist_add
 # train_data and test_data are tuples in the format (X, gt_pseudo_label, Y)
 # If get_pseudo_label is set to False, the gt_pseudo_label in each tuple will be None.
@@ -38,7 +38,7 @@ In this example, we build a simple LeNet5 network as the base model.
 .. code:: python
-from examples.models.nn import LeNet5
+from models.nn import LeNet5
 # The number of pseudo-labels is 10
 cls = LeNet5(num_classes=10)
@@ -11,9 +11,9 @@ from abl.learning import ABLModel, BasicNN
 from abl.reasoning import Reasoner
 from abl.data.structures import ListData
 from abl.utils import print_log
-from examples.hed.datasets import get_pretrain_data
-from examples.hed.utils import InfiniteSampler, gen_mappings
-from examples.models.nn import SymbolNetAutoencoder
+from datasets import get_pretrain_data
+from utils import InfiniteSampler, gen_mappings
+from models.nn import SymbolNetAutoencoder
 class HedBridge(SimpleBridge):
@@ -13,7 +13,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -23,14 +23,13 @@
 "import torch.nn as nn\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
-"from abl.learning import ABLModel, BasicNN\n",
-"from abl.utils import ABLLogger, print_log\n",
-"\n",
-"from bridge import HedBridge\n",
-"from consistency_metric import ConsistencyMetric\n",
 "from datasets import get_dataset, split_equation\n",
 "from models.nn import SymbolNet\n",
-"from reasoning import HedKB, HedReasoner"
+"from abl.learning import ABLModel, BasicNN\n",
+"from reasoning import HedKB, HedReasoner\n",
+"from consistency_metric import ConsistencyMetric\n",
+"from abl.utils import ABLLogger, print_log\n",
+"from bridge import HedBridge"
 ]
 },
 {
@@ -49,7 +48,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -67,7 +66,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -121,7 +120,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -242,7 +241,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -272,7 +271,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -298,7 +297,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -323,7 +322,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -347,7 +346,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -366,7 +365,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -383,7 +382,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -4,14 +4,13 @@ import argparse
 import torch
 import torch.nn as nn
-from abl.learning import ABLModel, BasicNN
-from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
-from abl.utils import ABLLogger, print_log
-from bridge import HedBridge
 from datasets import get_dataset, split_equation
 from models.nn import SymbolNet
+from abl.learning import ABLModel, BasicNN
 from reasoning import HedKB, HedReasoner
+from consistency_metric import ConsistencyMetric
+from abl.utils import ABLLogger, print_log
+from bridge import HedBridge
 def main():
@@ -82,7 +81,7 @@ def main():
     reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision)
     ### Building Evaluation Metrics
-    metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]
+    metric_list = [ConsistencyMetric(kb=kb)]
     ### Bridge Learning and Reasoning
     bridge = HedBridge(model, reasoner, metric_list)
@@ -1,16 +1,3 @@
-# coding: utf-8
-# ================================================================#
-# Copyright (C) 2021 Freecss All rights reserved.
-#
-# File Name :lenet5.py
-# Author :freecss
-# Email :karlfreecss@gmail.com
-# Created Date :2021/03/03
-# Description :
-#
-# ================================================================#
 import torch
 from torch import nn
@@ -24,14 +24,13 @@
 "import torch.nn as nn\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
+"from datasets import get_dataset\n",
+"from models.nn import SymbolNet\n",
 "from abl.learning import ABLModel, BasicNN\n",
 "from abl.reasoning import KBBase, Reasoner\n",
 "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
 "from abl.utils import ABLLogger, print_log\n",
-"from abl.bridge import SimpleBridge\n",
-"\n",
-"from datasets import get_dataset\n",
-"from models.nn import SymbolNet"
+"from abl.bridge import SimpleBridge"
 ]
 },
 {
@@ -1,19 +1,18 @@
-import os.path as osp
 import argparse
+import os.path as osp
 import numpy as np
 import torch
 from torch import nn
+from datasets import get_dataset
+from models.nn import SymbolNet
 from abl.learning import ABLModel, BasicNN
 from abl.reasoning import KBBase, GroundKB, Reasoner
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.utils import ABLLogger, print_log
 from abl.bridge import SimpleBridge
-from datasets import get_dataset
-from models.nn import SymbolNet
 class HwfKB(KBBase):
     def __init__(
@@ -1,16 +1,3 @@
-# coding: utf-8
-# ================================================================#
-# Copyright (C) 2021 Freecss All rights reserved.
-#
-# File Name :lenet5.py
-# Author :freecss
-# Email :karlfreecss@gmail.com
-# Created Date :2021/03/03
-# Description :
-#
-# ================================================================#
 import torch
 from torch import nn
@@ -5,14 +5,13 @@ import torch
 from torch import nn
 from torch.optim import RMSprop, lr_scheduler
-from abl.bridge import SimpleBridge
-from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
+from datasets import get_dataset
+from models.nn import LeNet5
 from abl.learning import ABLModel, BasicNN
 from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
+from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.utils import ABLLogger, print_log
-from datasets import get_dataset
-from models.nn import LeNet5
+from abl.bridge import SimpleBridge
 class AddKB(KBBase):
@@ -25,14 +25,13 @@
 "\n",
 "from torch.optim import RMSprop, lr_scheduler\n",
 "\n",
-"from abl.bridge import SimpleBridge\n",
-"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
+"from datasets import get_dataset\n",
+"from models.nn import LeNet5\n",
 "from abl.learning import ABLModel, BasicNN\n",
 "from abl.reasoning import KBBase, Reasoner\n",
+"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
 "from abl.utils import ABLLogger, print_log\n",
-"\n",
-"from datasets import get_dataset\n",
-"from models.nn import LeNet5"
+"from abl.bridge import SimpleBridge"
 ]
 },
 {
@@ -1,16 +1,3 @@
-# coding: utf-8
-# ================================================================#
-# Copyright (C) 2021 Freecss All rights reserved.
-#
-# File Name :lenet5.py
-# Author :freecss
-# Email :karlfreecss@gmail.com
-# Created Date :2021/03/03
-# Description :
-#
-# ================================================================#
 import numpy as np
 import torch
 from torch import nn
@@ -45,50 +32,3 @@ class LeNet5(nn.Module):
         x = self.fc2(x)
         x = self.fc3(x)
         return x
-class SymbolNet(nn.Module):
-    def __init__(self, num_classes=4, image_size=(28, 28, 1)):
-        super(SymbolNet, self).__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(1, 32, 5, stride=1),
-            nn.ReLU(),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.BatchNorm2d(32, momentum=0.99, eps=0.001),
-        )
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(32, 64, 5, padding=2, stride=1),
-            nn.ReLU(),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.BatchNorm2d(64, momentum=0.99, eps=0.001),
-        )
-        num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1)
-        self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
-        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
-        self.fc3 = nn.Sequential(nn.Linear(84, num_classes))
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        x = self.fc3(x)
-        return x
-class SymbolNetAutoencoder(nn.Module):
-    def __init__(self, num_classes=4, image_size=(28, 28, 1)):
-        super(SymbolNetAutoencoder, self).__init__()
-        self.base_model = SymbolNet(num_classes, image_size)
-        self.softmax = nn.Softmax(dim=1)
-        self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
-        self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU())
-    def forward(self, x):
-        x = self.base_model(x)
-        # x = self.softmax(x)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
@@ -1,94 +0,0 @@
-# coding: utf-8
-# ================================================================#
-# Copyright (C) 2021 Freecss All rights reserved.
-#
-# File Name :lenet5.py
-# Author :freecss
-# Email :karlfreecss@gmail.com
-# Created Date :2021/03/03
-# Description :
-#
-# ================================================================#
-import numpy as np
-import torch
-from torch import nn
-class LeNet5(nn.Module):
-    def __init__(self, num_classes=10, image_size=(28, 28)):
-        super(LeNet5, self).__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(1, 6, 3, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-        )
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)
-        )
-        self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU())
-        feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2
-        num_features = 16 * feature_map_size[0] * feature_map_size[1]
-        self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
-        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
-        self.fc3 = nn.Linear(84, num_classes)
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.conv3(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        x = self.fc3(x)
-        return x
-class SymbolNet(nn.Module):
-    def __init__(self, num_classes=4, image_size=(28, 28, 1)):
-        super(SymbolNet, self).__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(1, 32, 5, stride=1),
-            nn.ReLU(),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.BatchNorm2d(32, momentum=0.99, eps=0.001),
-        )
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(32, 64, 5, padding=2, stride=1),
-            nn.ReLU(),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.BatchNorm2d(64, momentum=0.99, eps=0.001),
-        )
-        num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1)
-        self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
-        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
-        self.fc3 = nn.Sequential(nn.Linear(84, num_classes))
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        x = self.fc3(x)
-        return x
-class SymbolNetAutoencoder(nn.Module):
-    def __init__(self, num_classes=4, image_size=(28, 28, 1)):
-        super(SymbolNetAutoencoder, self).__init__()
-        self.base_model = SymbolNet(num_classes, image_size)
-        self.softmax = nn.Softmax(dim=1)
-        self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
-        self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU())
-    def forward(self, x):
-        x = self.base_model(x)
-        # x = self.softmax(x)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
@@ -4,9 +4,9 @@ import argparse
 import numpy as np
 from sklearn.ensemble import RandomForestClassifier
-from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset
+from get_dataset import load_and_preprocess_dataset, split_dataset
 from abl.learning import ABLModel
-from examples.zoo.kb import ZooKB
+from kb import ZooKB
 from abl.reasoning import Reasoner
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.utils import ABLLogger, print_log, confidence_dist
@@ -13,7 +13,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -21,9 +21,9 @@
 "import os.path as osp\n",
 "import numpy as np\n",
 "from sklearn.ensemble import RandomForestClassifier\n",
-"from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset\n",
+"from get_dataset import load_and_preprocess_dataset, split_dataset\n",
 "from abl.learning import ABLModel\n",
-"from examples.zoo.kb import ZooKB\n",
+"from kb import ZooKB\n",
 "from abl.reasoning import Reasoner\n",
 "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
 "from abl.utils import ABLLogger, print_log, confidence_dist\n",
@@ -41,7 +41,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -58,7 +58,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -99,7 +99,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -127,7 +127,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -143,7 +143,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -166,7 +166,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -182,7 +182,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -228,7 +228,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -259,7 +259,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -278,7 +278,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -294,7 +294,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -366,7 +366,7 @@ | |||||
"name": "python", | "name": "python", | ||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.8.13" | |||||
"version": "3.8.18" | |||||
}, | }, | ||||
"orig_nbformat": 4, | "orig_nbformat": 4, | ||||
"vscode": { | "vscode": { | ||||