Browse Source

Merge branch 'Dev' of https://github.com/AbductiveLearning/ABL-Package into Dev

pull/1/head
Gao Enhao 1 year ago
parent
commit
b43ae6c189
14 changed files with 88 additions and 48 deletions
  1. +1
    -1
      README.md
  2. +3
    -3
      abl/learning/basic_nn.py
  3. +1
    -1
      abl/reasoning/kb.py
  4. +2
    -1
      abl/utils/__init__.py
  5. +10
    -1
      abl/utils/utils.py
  6. +1
    -1
      docs/Overview/Installation.rst
  7. +1
    -1
      docs/README.rst
  8. +11
    -3
      examples/hed/main.py
  9. +11
    -3
      examples/hwf/main.py
  10. +10
    -3
      examples/mnist_add/main.py
  11. +20
    -11
      examples/zoo/main.py
  12. +4
    -6
      examples/zoo/zoo.ipynb
  13. +7
    -7
      pyproject.toml
  14. +6
    -6
      requirements.txt

+ 1
- 1
README.md View File

@@ -53,7 +53,7 @@ For Linux users:
$ sudo apt-get install swi-prolog
```

For Windows and Mac users, please refer to the [Swi-Prolog Download Page](https://www.swi-prolog.org/Download.html).
For Windows and Mac users, please refer to the [Swi-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md).

## Examples



+ 3
- 3
abl/learning/basic_nn.py View File

@@ -66,9 +66,9 @@ class BasicNN:
num_workers: int = 0,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
train_transform: Callable[..., Any] = None,
test_transform: Callable[..., Any] = None,
collate_fn: Callable[[List[Any]], Any] = None,
train_transform: Optional[Callable[..., Any]] = None,
test_transform: Optional[Callable[..., Any]] = None,
collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> None:
if not isinstance(model, torch.nn.Module):
raise TypeError("model must be an instance of torch.nn.Module")


+ 1
- 1
abl/reasoning/kb.py View File

@@ -471,7 +471,7 @@ class PrologKB(KBBase):
except (IndexError, ImportError):
print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \
using the command 'sudo apt-get install swi-prolog' for Linux users, \
or download it from https://www.swi-prolog.org/Download.html for Windows and Mac users.")
or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.")
self.prolog = pyswip.Prolog()
self.pl_file = pl_file


+ 2
- 1
abl/utils/__init__.py View File

@@ -1,6 +1,6 @@
from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple

__all__ = [
"Cache",
@@ -12,4 +12,5 @@ __all__ = [
"reform_list",
"to_hashable",
"abl_cache",
"tab_data_to_tuple",
]

+ 10
- 1
abl/utils/utils.py View File

@@ -154,4 +154,13 @@ def restore_from_hashable(x):
return [restore_from_hashable(item) for item in x]
return x


def tab_data_to_tuple(X, y, reasoning_result = 0):
'''
Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result.
If X is None, return None.
'''
if X is None:
return None
if len(X) != len(y):
raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)))
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))

+ 1
- 1
docs/Overview/Installation.rst View File

@@ -31,4 +31,4 @@ For Linux users:

$ sudo apt-get install swi-prolog

For Windows and Mac users, please refer to the `Swi-Prolog Download Page <https://www.swi-prolog.org/Download.html>`_.
For Windows and Mac users, please refer to the `Swi-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_.

+ 1
- 1
docs/README.rst View File

@@ -51,7 +51,7 @@ For Linux users:

$ sudo apt-get install swi-prolog

For Windows and Mac users, please refer to the `Swi-Prolog Download Page <https://www.swi-prolog.org/Download.html>`_.
For Windows and Mac users, please refer to the `Swi-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_.

References
----------


+ 11
- 3
examples/hed/main.py View File

@@ -46,13 +46,20 @@ def main():
)

args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
total_train_data = get_dataset(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_dataset(train=False)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
@@ -75,6 +82,8 @@ def main():
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
kb = HedKB()

@@ -82,14 +91,13 @@ def main():
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [ConsistencyMetric(kb=kb)]

### Bridge Learning and Reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = HedBridge(model, reasoner, metric_list)

# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


+ 11
- 3
examples/hwf/main.py View File

@@ -113,12 +113,19 @@ def main():
)

args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the HWF example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
loss_fn = nn.CrossEntropyLoss()
@@ -140,6 +147,8 @@ def main():
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.ground:
kb = HwfGroundKB()
@@ -152,14 +161,13 @@ def main():
)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

### Bridge Learning and Reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)

# Build logger
print_log("Abductive Learning on the HWF example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


+ 10
- 3
examples/mnist_add/main.py View File

@@ -78,11 +78,17 @@ def main():

args = parser.parse_args()

# Build logger
print_log("Abductive Learning on the MNIST Addition example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
@@ -112,6 +118,8 @@ def main():
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.prolog:
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
@@ -126,14 +134,13 @@ def main():
)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

### Bridge Learning and Reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)

# Build logger
print_log("Abductive Learning on the MNIST Addition example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


+ 20
- 11
examples/zoo/main.py View File

@@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
from abl.learning import ABLModel
from abl.reasoning import Reasoner
from abl.utils import ABLLogger, confidence_dist, print_log
from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple

from get_dataset import load_and_preprocess_dataset, split_dataset
from kb import ZooKB


def transform_tab_data(X, y):
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))

def consitency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob
@@ -30,21 +28,31 @@ def main():
"--loops", type=int, default=3, help="number of loop iterations (default : 3)"
)
args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the ZOO example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
X, y = load_and_preprocess_dataset(dataset_id=62)
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
label_data = transform_tab_data(X_label, y_label)
test_data = transform_tab_data(X_test, y_test)
train_data = transform_tab_data(X_unlabel, y_unlabel)
label_data = tab_data_to_tuple(X_label, y_label)
test_data = tab_data_to_tuple(X_test, y_test)
train_data = tab_data_to_tuple(X_unlabel, y_unlabel)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build base model
base_model = RandomForestClassifier()

# Build ABLModel
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
kb = ZooKB()
@@ -52,16 +60,17 @@ def main():
reasoner = Reasoner(kb, dist_func=consitency)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
# Build logger
print_log("Abductive Learning on the ZOO example.", logger="current")
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
### Bridging learning and reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
# Performing training and testing
print_log("------- Use labeled data to pretrain the model -----------", logger="current")
base_model.fit(X_label, y_label)


+ 4
- 6
examples/zoo/zoo.ipynb View File

@@ -27,7 +27,7 @@
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
"from abl.learning import ABLModel\n",
"from abl.reasoning import Reasoner\n",
"from abl.utils import ABLLogger, confidence_dist, print_log\n",
"from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n",
"\n",
"from get_dataset import load_and_preprocess_dataset, split_dataset\n",
"from kb import ZooKB"
@@ -106,11 +106,9 @@
"metadata": {},
"outputs": [],
"source": [
"def transform_tab_data(X, y):\n",
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"label_data = transform_tab_data(X_label, y_label)\n",
"test_data = transform_tab_data(X_test, y_test)\n",
"train_data = transform_tab_data(X_unlabel, y_unlabel)"
"label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n",
"test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n",
"train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)"
]
},
{


+ 7
- 7
pyproject.toml View File

@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "abl"
version = "0.1.4"
authors = [
{ name="LAMDA 2023" },
{ name="LAMDA 2024" },
]
description = "Abductive learning package project"
readme = "README.md"
@@ -25,12 +25,12 @@ classifiers = [
"Programming Language :: Python :: 3.9",
]
dependencies = [
"numpy",
"pyswip==0.2.9",
"torch",
"torchvision",
"zoopt",
"termcolor"
"numpy>=1.15.0",
"pyswip>=0.2.9",
"torch>=1.11.0",
"torchvision>=0.12.0",
"zoopt>=0.3.0",
"termcolor>=2.3.0"
]

[project.urls]


+ 6
- 6
requirements.txt View File

@@ -1,6 +1,6 @@
numpy
pyswip==0.2.9
torch
torchvision
zoopt
termcolor
numpy>=1.15.0,
pyswip>=0.2.9,
torch>=1.11.0,
torchvision>=0.12.0,
zoopt>=0.3.0,
termcolor>=2.3.0

Loading…
Cancel
Save