@@ -28,13 +28,13 @@ model. | |||
import torch | |||
import torch.nn as nn | |||
import matplotlib.pyplot as plt | |||
from examples.hed.datasets import get_dataset, split_equation | |||
from examples.models.nn import SymbolNet | |||
from datasets import get_dataset, split_equation | |||
from models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from examples.hed.reasoning import HedKB, HedReasoner | |||
from reasoning import HedKB, HedReasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log | |||
from examples.hed.bridge import HedBridge | |||
from bridge import HedBridge | |||
Working with Data | |||
----------------- | |||
@@ -26,8 +26,8 @@ machine learning model. | |||
import torch | |||
import torch.nn as nn | |||
import matplotlib.pyplot as plt | |||
from examples.hwf.datasets import get_dataset | |||
from examples.models.nn import SymbolNet | |||
from datasets import get_dataset | |||
from models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
@@ -27,8 +27,8 @@ machine learning model. | |||
from torch.optim import RMSprop, lr_scheduler | |||
from examples.mnist_add.datasets import get_dataset | |||
from examples.models.nn import LeNet5 | |||
from datasets import get_dataset | |||
from models.nn import LeNet5 | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
@@ -22,9 +22,9 @@ further update the learning model. | |||
import os.path as osp | |||
import numpy as np | |||
from sklearn.ensemble import RandomForestClassifier | |||
from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset | |||
from get_dataset import load_and_preprocess_dataset, split_dataset | |||
from abl.learning import ABLModel | |||
from examples.zoo.kb import ZooKB | |||
from kb import ZooKB | |||
from abl.reasoning import Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log, confidence_dist | |||
@@ -21,7 +21,7 @@ In the MNIST Addition task, the data loading looks like | |||
.. code:: python | |||
from examples.mnist_add.datasets.get_mnist_add import get_mnist_add | |||
from datasets.get_mnist_add import get_mnist_add | |||
# train_data and test_data are tuples in the format (X, gt_pseudo_label, Y) | |||
# If get_pseudo_label is set to False, the gt_pseudo_label in each tuple will be None. | |||
@@ -38,7 +38,7 @@ In this example, we build a simple LeNet5 network as the base model. | |||
.. code:: python | |||
from examples.models.nn import LeNet5 | |||
from models.nn import LeNet5 | |||
# The number of pseudo-labels is 10 | |||
cls = LeNet5(num_classes=10) | |||
@@ -11,9 +11,9 @@ from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import Reasoner | |||
from abl.data.structures import ListData | |||
from abl.utils import print_log | |||
from examples.hed.datasets import get_pretrain_data | |||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||
from examples.models.nn import SymbolNetAutoencoder | |||
from datasets import get_pretrain_data | |||
from utils import InfiniteSampler, gen_mappings | |||
from models.nn import SymbolNetAutoencoder | |||
class HedBridge(SimpleBridge): | |||
@@ -13,7 +13,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -23,14 +23,13 @@ | |||
"import torch.nn as nn\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"\n", | |||
"from bridge import HedBridge\n", | |||
"from consistency_metric import ConsistencyMetric\n", | |||
"from datasets import get_dataset, split_equation\n", | |||
"from models.nn import SymbolNet\n", | |||
"from reasoning import HedKB, HedReasoner" | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from reasoning import HedKB, HedReasoner\n", | |||
"from consistency_metric import ConsistencyMetric\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from bridge import HedBridge" | |||
] | |||
}, | |||
{ | |||
@@ -49,7 +48,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -67,7 +66,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -121,7 +120,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -242,7 +241,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -272,7 +271,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -298,7 +297,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -323,7 +322,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -347,7 +346,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -366,7 +365,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -383,7 +382,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -4,14 +4,13 @@ import argparse | |||
import torch | |||
import torch.nn as nn | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log | |||
from bridge import HedBridge | |||
from datasets import get_dataset, split_equation | |||
from models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from reasoning import HedKB, HedReasoner | |||
from consistency_metric import ConsistencyMetric | |||
from abl.utils import ABLLogger, print_log | |||
from bridge import HedBridge | |||
def main(): | |||
@@ -82,7 +81,7 @@ def main(): | |||
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) | |||
### Building Evaluation Metrics | |||
metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")] | |||
metric_list = [ConsistencyMetric(kb=kb)] | |||
### Bridge Learning and Reasoning | |||
bridge = HedBridge(model, reasoner, metric_list) | |||
@@ -1,16 +1,3 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import torch | |||
from torch import nn | |||
@@ -24,14 +24,13 @@ | |||
"import torch.nn as nn\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"\n", | |||
"from datasets import get_dataset\n", | |||
"from models.nn import SymbolNet\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import KBBase, Reasoner\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from abl.bridge import SimpleBridge\n", | |||
"\n", | |||
"from datasets import get_dataset\n", | |||
"from models.nn import SymbolNet" | |||
"from abl.bridge import SimpleBridge" | |||
] | |||
}, | |||
{ | |||
@@ -1,19 +1,18 @@ | |||
import os.path as osp | |||
import argparse | |||
import os.path as osp | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
from datasets import get_dataset | |||
from models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, GroundKB, Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log | |||
from abl.bridge import SimpleBridge | |||
from datasets import get_dataset | |||
from models.nn import SymbolNet | |||
class HwfKB(KBBase): | |||
def __init__( | |||
@@ -1,16 +1,3 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import torch | |||
from torch import nn | |||
@@ -5,14 +5,13 @@ import torch | |||
from torch import nn | |||
from torch.optim import RMSprop, lr_scheduler | |||
from abl.bridge import SimpleBridge | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from datasets import get_dataset | |||
from models.nn import LeNet5 | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log | |||
from datasets import get_dataset | |||
from models.nn import LeNet5 | |||
from abl.bridge import SimpleBridge | |||
class AddKB(KBBase): | |||
@@ -25,14 +25,13 @@ | |||
"\n", | |||
"from torch.optim import RMSprop, lr_scheduler\n", | |||
"\n", | |||
"from abl.bridge import SimpleBridge\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from datasets import get_dataset\n", | |||
"from models.nn import LeNet5\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import KBBase, Reasoner\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"\n", | |||
"from datasets import get_dataset\n", | |||
"from models.nn import LeNet5" | |||
"from abl.bridge import SimpleBridge" | |||
] | |||
}, | |||
{ | |||
@@ -1,16 +1,3 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
@@ -45,50 +32,3 @@ class LeNet5(nn.Module): | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNet(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNet, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 32, 5, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
) | |||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNetAutoencoder(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNetAutoencoder, self).__init__() | |||
self.base_model = SymbolNet(num_classes, image_size) | |||
self.softmax = nn.Softmax(dim=1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
def forward(self, x): | |||
x = self.base_model(x) | |||
# x = self.softmax(x) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
return x |
@@ -1,94 +0,0 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
class LeNet5(nn.Module): | |||
def __init__(self, num_classes=10, image_size=(28, 28)): | |||
super(LeNet5, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 6, 3, padding=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) | |||
) | |||
self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU()) | |||
feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2 | |||
num_features = 16 * feature_map_size[0] * feature_map_size[1] | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Linear(84, num_classes) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = self.conv3(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNet(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNet, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 32, 5, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
) | |||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNetAutoencoder(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNetAutoencoder, self).__init__() | |||
self.base_model = SymbolNet(num_classes, image_size) | |||
self.softmax = nn.Softmax(dim=1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
def forward(self, x): | |||
x = self.base_model(x) | |||
# x = self.softmax(x) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
return x |
@@ -4,9 +4,9 @@ import argparse | |||
import numpy as np | |||
from sklearn.ensemble import RandomForestClassifier | |||
from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset | |||
from get_dataset import load_and_preprocess_dataset, split_dataset | |||
from abl.learning import ABLModel | |||
from examples.zoo.kb import ZooKB | |||
from kb import ZooKB | |||
from abl.reasoning import Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log, confidence_dist | |||
@@ -13,7 +13,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -21,9 +21,9 @@ | |||
"import os.path as osp\n", | |||
"import numpy as np\n", | |||
"from sklearn.ensemble import RandomForestClassifier\n", | |||
"from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset\n", | |||
"from get_dataset import load_and_preprocess_dataset, split_dataset\n", | |||
"from abl.learning import ABLModel\n", | |||
"from examples.zoo.kb import ZooKB\n", | |||
"from kb import ZooKB\n", | |||
"from abl.reasoning import Reasoner\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.utils import ABLLogger, print_log, confidence_dist\n", | |||
@@ -41,7 +41,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -58,7 +58,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -99,7 +99,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -127,7 +127,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -143,7 +143,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -166,7 +166,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -182,7 +182,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -228,7 +228,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -259,7 +259,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -278,7 +278,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -294,7 +294,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -366,7 +366,7 @@ | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.8.13" | |||
"version": "3.8.18" | |||
}, | |||
"orig_nbformat": 4, | |||
"vscode": { | |||