@@ -22,14 +22,15 @@ | |||
"import torch\n", | |||
"import torch.nn as nn\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"from examples.hed.datasets import get_dataset, split_equation\n", | |||
"from examples.models.nn import SymbolNet\n", | |||
"\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from examples.hed.reasoning import HedKB, HedReasoner\n", | |||
"from abl.data.evaluation import SymbolAccuracy\n", | |||
"from examples.hed.consistency_metric import ConsistencyMetric\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from examples.hed.bridge import HedBridge" | |||
"\n", | |||
"from bridge import HedBridge\n", | |||
"from consistency_metric import ConsistencyMetric\n", | |||
"from datasets import get_dataset, split_equation\n", | |||
"from models.nn import SymbolNet\n", | |||
"from reasoning import HedKB, HedReasoner" | |||
] | |||
}, | |||
{ | |||
@@ -382,7 +383,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 11, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -415,7 +416,7 @@ | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.8.13" | |||
"version": "3.8.18" | |||
}, | |||
"orig_nbformat": 4, | |||
"vscode": { | |||
@@ -4,13 +4,15 @@ import argparse | |||
import torch | |||
import torch.nn as nn | |||
from examples.hed.datasets import get_dataset, split_equation | |||
from examples.models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from examples.hed.reasoning import HedKB, HedReasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log | |||
from examples.hed.bridge import HedBridge | |||
from bridge import HedBridge | |||
from datasets import get_dataset, split_equation | |||
from models.nn import SymbolNet | |||
from reasoning import HedKB, HedReasoner | |||
def main(): | |||
parser = argparse.ArgumentParser(description="Handwritten Equation Decipherment example") | |||
@@ -54,7 +56,7 @@ def main(): | |||
# Build necessary components for BasicNN | |||
cls = SymbolNet(num_classes=4) | |||
loss_fn = nn.CrossEntropyLoss() | |||
optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_deccay) | |||
optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_decay) | |||
use_cuda = not args.no_cuda and torch.cuda.is_available() | |||
device = torch.device("cuda" if use_cuda else "cpu") | |||
@@ -63,7 +65,7 @@ def main(): | |||
cls, | |||
loss_fn, | |||
optimizer, | |||
device, | |||
device=device, | |||
batch_size=args.batch_size, | |||
num_epochs=args.epochs, | |||
stop_loss=None, | |||
@@ -81,7 +83,7 @@ def main(): | |||
### Building Evaluation Metrics | |||
metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")] | |||
### Bridge Learning and Reasoning | |||
bridge = HedBridge(model, reasoner, metric_list) | |||
@@ -0,0 +1,62 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import torch | |||
from torch import nn | |||
class SymbolNet(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNet, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 32, 5, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
) | |||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNetAutoencoder(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNetAutoencoder, self).__init__() | |||
self.base_model = SymbolNet(num_classes, image_size) | |||
self.softmax = nn.Softmax(dim=1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
def forward(self, x): | |||
x = self.base_model(x) | |||
# x = self.softmax(x) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
return x |
@@ -1,4 +1,3 @@ | |||
import os | |||
import os.path as osp | |||
import argparse | |||
@@ -6,14 +5,15 @@ import numpy as np | |||
import torch | |||
from torch import nn | |||
from examples.hwf.datasets import get_dataset | |||
from examples.models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, GroundKB, Reasoner | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.utils import ABLLogger, print_log | |||
from abl.bridge import SimpleBridge | |||
from datasets import get_dataset | |||
from models.nn import SymbolNet | |||
class HwfKB(KBBase): | |||
def __init__( | |||
@@ -0,0 +1,46 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import torch | |||
from torch import nn | |||
class SymbolNet(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNet, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 32, 5, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
) | |||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x |
@@ -10,8 +10,9 @@ from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner | |||
from abl.utils import ABLLogger, print_log | |||
from examples.mnist_add.datasets import get_dataset | |||
from examples.models.nn import LeNet5 | |||
from datasets import get_dataset | |||
from models.nn import LeNet5 | |||
class AddKB(KBBase): | |||
@@ -25,13 +25,14 @@ | |||
"\n", | |||
"from torch.optim import RMSprop, lr_scheduler\n", | |||
"\n", | |||
"from examples.mnist_add.datasets import get_dataset\n", | |||
"from examples.models.nn import LeNet5\n", | |||
"from abl.bridge import SimpleBridge\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import KBBase, Reasoner\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from abl.bridge import SimpleBridge" | |||
"\n", | |||
"from datasets import get_dataset\n", | |||
"from models.nn import LeNet5" | |||
] | |||
}, | |||
{ | |||
@@ -425,7 +426,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 14, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -0,0 +1,94 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
class LeNet5(nn.Module): | |||
def __init__(self, num_classes=10, image_size=(28, 28)): | |||
super(LeNet5, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 6, 3, padding=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) | |||
) | |||
self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU()) | |||
feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2 | |||
num_features = 16 * feature_map_size[0] * feature_map_size[1] | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Linear(84, num_classes) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = self.conv3(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNet(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNet, self).__init__() | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(1, 32, 5, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
) | |||
self.conv2 = nn.Sequential( | |||
nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
nn.ReLU(), | |||
nn.MaxPool2d(kernel_size=2, stride=2), | |||
nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
) | |||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
return x | |||
class SymbolNetAutoencoder(nn.Module): | |||
def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
super(SymbolNetAutoencoder, self).__init__() | |||
self.base_model = SymbolNet(num_classes, image_size) | |||
self.softmax = nn.Softmax(dim=1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
def forward(self, x): | |||
x = self.base_model(x) | |||
# x = self.softmax(x) | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
return x |