Browse Source

[MNT] minor modification

pull/3/head
Gao Enhao 2 years ago
parent
commit
6e88dc9469
3 changed files with 7 additions and 3 deletions
  1. +1
    -1
      abl/reasoning/__init__.py
  2. +3
    -0
      abl/reasoning/reasoner.py
  3. +3
    -2
      examples/models/nn.py

+ 1
- 1
abl/reasoning/__init__.py View File

@@ -1,2 +1,2 @@
from .reasoner import ReasonerBase
from .kb import KBBase
from .kb import KBBase, prolog_KB

+ 3
- 0
abl/reasoning/reasoner.py View File

@@ -28,6 +28,9 @@ class ReasonerBase(abc.ABC):
)
else:
self.mapping = mapping
self.set_remapping()
def set_remapping(self):
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_cost_list(self, pseudo_label, pred_res_prob, candidates):


+ 3
- 2
examples/models/nn.py View File

@@ -70,7 +70,7 @@ class SymbolNet(nn.Module):
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1)
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
self.fc3 = nn.Linear(84, num_classes)
self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1))

def forward(self, x):
x = self.conv1(x)
@@ -86,6 +86,7 @@ class SymbolNetAutoencoder(nn.Module):
def __init__(self, num_classes=4, image_size=(28, 28, 1)):
super(SymbolNetAutoencoder, self).__init__()
self.base_model = SymbolNet(num_classes, image_size)
self.softmax = nn.Softmax(dim=1)
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
self.fc2 = nn.Sequential(
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()
@@ -93,7 +94,7 @@ class SymbolNetAutoencoder(nn.Module):

def forward(self, x):
x = self.base_model(x)
x = nn.Softmax(x, dim=1)
# x = self.softmax(x)
x = self.fc1(x)
x = self.fc2(x)
return x

Loading…
Cancel
Save