Browse Source

Merge branch 'Dev' of https://github.com/AbductiveLearning/ABL-Package into Dev

pull/1/head
Tony-HYX 1 year ago
parent
commit
70b2cc5db2
5 changed files with 16 additions and 13 deletions
  1. +9
    -4
      abl/learning/basic_nn.py
  2. +1
    -1
      docs/Examples/HED.rst
  3. +4
    -6
      docs/Examples/Zoo.rst
  4. +1
    -1
      pyproject.toml
  5. +1
    -1
      requirements.txt

+ 9
- 4
abl/learning/basic_nn.py View File

@@ -2,7 +2,7 @@ from __future__ import annotations

import logging
import os
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy
import torch
@@ -28,7 +28,7 @@ class BasicNN:
The learning rate scheduler used for training, which will be called
at the end of each run of the ``fit`` method. It should implement the
``step`` method, by default None.
device : torch.device, optional
device : Union[torch.device, str]
The device on which the model will be trained or used for prediction,
by default torch.device("cpu").
batch_size : int, optional
@@ -59,7 +59,7 @@ class BasicNN:
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[Callable[..., Any]] = None,
device: torch.device = torch.device("cpu"),
device: Union[torch.device, str] = torch.device("cpu"),
batch_size: int = 32,
num_epochs: int = 1,
stop_loss: Optional[float] = 0.0001,
@@ -79,7 +79,12 @@ class BasicNN:
if scheduler is not None and not hasattr(scheduler, "step"):
raise NotImplementedError("scheduler should implement the ``step`` method")
if not isinstance(device, torch.device):
raise TypeError("device must be an instance of torch.device")
if not isinstance(device, str):
raise TypeError(
"device must be an instance of torch.device or a str indicates the target device"
)
else:
device = torch.device(device)
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if not isinstance(num_epochs, int):


+ 1
- 1
docs/Examples/HED.rst View File

@@ -188,7 +188,7 @@ sklearn-style interface.
cls,
loss_fn,
optimizer,
device,
device=device,
batch_size=32,
num_epochs=1,
stop_loss=None,


+ 4
- 6
docs/Examples/Zoo.rst View File

@@ -32,7 +32,7 @@ further update the learning model.
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
from abl.learning import ABLModel
from abl.reasoning import Reasoner
from abl.utils import ABLLogger, confidence_dist, print_log
from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple

from get_dataset import load_and_preprocess_dataset, split_dataset
from kb import ZooKB
@@ -91,11 +91,9 @@ indicating no rules are violated.

.. code:: ipython3

def transform_tab_data(X, y):
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
label_data = transform_tab_data(X_label, y_label)
test_data = transform_tab_data(X_test, y_test)
train_data = transform_tab_data(X_unlabel, y_unlabel)
label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)
data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)
train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)

Building the Learning Part
--------------------------


+ 1
- 1
pyproject.toml View File

@@ -25,7 +25,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
]
dependencies = [
"numpy>=1.21.6",
"numpy>=1.15.0",
"pyswip>=0.2.9",
"torch>=1.11.0",
"torchvision>=0.12.0",


+ 1
- 1
requirements.txt View File

@@ -1,4 +1,4 @@
numpy>=1.21.6,
numpy>=1.15.0,
pyswip>=0.2.9,
torch>=1.11.0,
torchvision>=0.12.0,


Loading…
Cancel
Save