diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py
index 9aec7cb..8ec0ebd 100644
--- a/abl/bridge/base_bridge.py
+++ b/abl/bridge/base_bridge.py
@@ -1,9 +1,9 @@
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Tuple, Union

+from ..data.structures import ListData
 from ..learning import ABLModel
 from ..reasoning import Reasoner
-from ..data.structures import ListData


 class BaseBridge(metaclass=ABCMeta):
diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py
index e24706e..44a41e9 100644
--- a/abl/bridge/simple_bridge.py
+++ b/abl/bridge/simple_bridge.py
@@ -4,9 +4,9 @@ from typing import Any, List, Optional, Tuple, Union
 from numpy import ndarray

 from ..data.evaluation import BaseMetric
+from ..data.structures import ListData
 from ..learning import ABLModel
 from ..reasoning import Reasoner
-from ..data.structures import ListData
 from ..utils import print_log
 from .base_bridge import BaseBridge

diff --git a/abl/data/__init__.py b/abl/data/__init__.py
index 3dc4849..8084dd7 100644
--- a/abl/data/__init__.py
+++ b/abl/data/__init__.py
@@ -1,2 +1,4 @@
-from .evaluation import *
-from .structures import *
\ No newline at end of file
+from .evaluation import BaseMetric, ReasoningMetric, SymbolAccuracy
+from .structures import ListData
+
+__all__ = ["BaseMetric", "ReasoningMetric", "SymbolAccuracy", "ListData"]
diff --git a/abl/data/evaluation/base_metric.py b/abl/data/evaluation/base_metric.py
index 37e36dd..61f6428 100644
--- a/abl/data/evaluation/base_metric.py
+++ b/abl/data/evaluation/base_metric.py
@@ -2,8 +2,8 @@ import logging
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional

-from ..structures import ListData
 from ...utils import print_log
+from ..structures import ListData


 class BaseMetric(metaclass=ABCMeta):
diff --git a/abl/learning/__init__.py b/abl/learning/__init__.py
index c3cfa0a..ad016a6 100644
--- a/abl/learning/__init__.py
+++ b/abl/learning/__init__.py
@@ -1,5 +1,5 @@
 from .abl_model import ABLModel
 from .basic_nn import BasicNN
-from .torch_dataset import *
+from .torch_dataset import ClassificationDataset, PredictionDataset, RegressionDataset

-__all__ = ["ABLModel", "BasicNN"]
+__all__ = ["ABLModel", "BasicNN", "ClassificationDataset", "PredictionDataset", "RegressionDataset"]
diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py
index 15e58b6..0a346e5 100644
--- a/abl/learning/basic_nn.py
+++ b/abl/learning/basic_nn.py
@@ -8,8 +8,8 @@ import numpy
 import torch
 from torch.utils.data import DataLoader

-from .torch_dataset import ClassificationDataset, PredictionDataset
 from ..utils.logger import print_log
+from .torch_dataset import ClassificationDataset, PredictionDataset


 class BasicNN:
diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py
index 5534e83..e8b1b8c 100644
--- a/abl/reasoning/kb.py
+++ b/abl/reasoning/kb.py
@@ -1,17 +1,17 @@
 import bisect
-import os
 import inspect
 import logging
+import os
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from itertools import combinations, product
 from multiprocessing import Pool
-from typing import Callable, Any, List, Optional
+from typing import Any, Callable, List, Optional

 import numpy as np

-from ..utils.logger import print_log
 from ..utils.cache import abl_cache
+from ..utils.logger import print_log
 from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable


diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py
index e7e609e..0799605 100644
--- a/abl/reasoning/reasoner.py
+++ b/abl/reasoning/reasoner.py
@@ -1,11 +1,11 @@
 import inspect
-from typing import Callable, Any, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union

 import numpy as np
 from zoopt import Dimension, Objective, Opt, Parameter, Solution

-from ..reasoning import KBBase
 from ..data.structures import ListData
+from ..reasoning import KBBase
 from ..utils.utils import confidence_dist, hamming_dist


diff --git a/abl/utils/__init__.py b/abl/utils/__init__.py
index cf22485..d69e09b 100644
--- a/abl/utils/__init__.py
+++ b/abl/utils/__init__.py
@@ -1,12 +1,6 @@
 from .cache import Cache, abl_cache
 from .logger import ABLLogger, print_log
-from .utils import (
-    confidence_dist,
-    flatten,
-    hamming_dist,
-    reform_list,
-    to_hashable,
-)
+from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable

 __all__ = [
     "Cache",
diff --git a/docs/conf.py b/docs/conf.py
index d67b6a8..b9bfbb9 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,9 +1,11 @@
 import os
 import re
 import sys
+
 from docutils import nodes
 from docutils.parsers.rst import roles

+
 def colored_text_role(role, rawtext, text, lineno, inliner, options={}, content=[]):
     node = nodes.inline(rawtext, text, classes=[role])
     return [node], []
diff --git a/examples/hed/bridge.py b/examples/hed/bridge.py
index dd98025..0386d8c 100644
--- a/examples/hed/bridge.py
+++ b/examples/hed/bridge.py
@@ -5,15 +5,16 @@ from typing import Any, List, Optional, Tuple, Union
 import torch

 from abl.bridge import SimpleBridge
-from abl.learning.torch_dataset import RegressionDataset
 from abl.data.evaluation import BaseMetric
+from abl.data.structures import ListData
 from abl.learning import ABLModel, BasicNN
+from abl.learning.torch_dataset import RegressionDataset
 from abl.reasoning import Reasoner
-from abl.data.structures import ListData
 from abl.utils import print_log
+
 from datasets import get_pretrain_data
-from utils import InfiniteSampler, gen_mappings
 from models.nn import SymbolNetAutoencoder
+from utils import InfiniteSampler, gen_mappings


 class HedBridge(SimpleBridge):
diff --git a/examples/hed/consistency_metric.py b/examples/hed/consistency_metric.py
index 5c68eb6..48f0a5e 100644
--- a/examples/hed/consistency_metric.py
+++ b/examples/hed/consistency_metric.py
@@ -1,8 +1,8 @@
 from typing import Optional

-from abl.reasoning import KBBase
-from abl.data.structures import ListData
 from abl.data.evaluation.base_metric import BaseMetric
+from abl.data.structures import ListData
+from abl.reasoning import KBBase


 class ConsistencyMetric(BaseMetric):
diff --git a/examples/hed/datasets/__init__.py b/examples/hed/datasets/__init__.py
index ad88c85..b07b583 100644
--- a/examples/hed/datasets/__init__.py
+++ b/examples/hed/datasets/__init__.py
@@ -1,4 +1,3 @@
 from .get_dataset import get_dataset, get_pretrain_data, split_equation

-
-__all__ = ["get_dataset", "get_pretrain_data", "split_equation"]
\ No newline at end of file
+__all__ = ["get_dataset", "get_pretrain_data", "split_equation"]
diff --git a/examples/hed/datasets/get_dataset.py b/examples/hed/datasets/get_dataset.py
index fb80f65..ced1463 100644
--- a/examples/hed/datasets/get_dataset.py
+++ b/examples/hed/datasets/get_dataset.py
@@ -2,11 +2,11 @@ import os
 import os.path as osp
 import pickle
 import random
-import gdown
 import zipfile
 from collections import defaultdict

 import cv2
+import gdown
 import numpy as np
 from torchvision.transforms import transforms

diff --git a/examples/hed/main.py b/examples/hed/main.py
index 07a6321..e2233c4 100644
--- a/examples/hed/main.py
+++ b/examples/hed/main.py
@@ -1,16 +1,17 @@
-import os.path as osp
 import argparse
+import os.path as osp

 import torch
 import torch.nn as nn

-from datasets import get_dataset, split_equation
-from models.nn import SymbolNet
 from abl.learning import ABLModel, BasicNN
-from reasoning import HedKB, HedReasoner
-from consistency_metric import ConsistencyMetric
 from abl.utils import ABLLogger, print_log
+
 from bridge import HedBridge
+from consistency_metric import ConsistencyMetric
+from datasets import get_dataset, split_equation
+from models.nn import SymbolNet
+from reasoning import HedKB, HedReasoner


 def main():
diff --git a/examples/hed/reasoning/reasoning.py b/examples/hed/reasoning/reasoning.py
index 4263cae..788fbc3 100644
--- a/examples/hed/reasoning/reasoning.py
+++ b/examples/hed/reasoning/reasoning.py
@@ -1,6 +1,8 @@
+import math
 import os
+
 import numpy as np
-import math
+
 from abl.reasoning import PrologKB, Reasoner
 from abl.utils import reform_list

diff --git a/examples/hwf/datasets/get_dataset.py b/examples/hwf/datasets/get_dataset.py
index 6c79d0f..d89b1e3 100644
--- a/examples/hwf/datasets/get_dataset.py
+++ b/examples/hwf/datasets/get_dataset.py
@@ -1,8 +1,8 @@
 import json
 import os
-import gdown
 import zipfile

+import gdown
 from PIL import Image
 from torchvision.transforms import transforms

diff --git a/examples/hwf/main.py b/examples/hwf/main.py
index dbfcc5c..328e849 100644
--- a/examples/hwf/main.py
+++ b/examples/hwf/main.py
@@ -5,13 +5,14 @@ import numpy as np
 import torch
 from torch import nn

-from datasets import get_dataset
-from models.nn import SymbolNet
-from abl.learning import ABLModel, BasicNN
-from abl.reasoning import KBBase, GroundKB, Reasoner
+from abl.bridge import SimpleBridge
 from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
+from abl.learning import ABLModel, BasicNN
+from abl.reasoning import GroundKB, KBBase, Reasoner
 from abl.utils import ABLLogger, print_log
-from abl.bridge import SimpleBridge
+
+from datasets import get_dataset
+from models.nn import SymbolNet


 class HwfKB(KBBase):
diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py
index cc6af7b..ab1fa98 100644
--- a/examples/mnist_add/main.py
+++ b/examples/mnist_add/main.py
@@ -5,13 +5,14 @@ import torch
 from torch import nn
 from torch.optim import RMSprop, lr_scheduler

-from datasets import get_dataset
-from models.nn import LeNet5
+from abl.bridge import SimpleBridge
+from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.learning import ABLModel, BasicNN
 from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
-from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.utils import ABLLogger, print_log
-from abl.bridge import SimpleBridge
+
+from datasets import get_dataset
+from models.nn import LeNet5


 class AddKB(KBBase):
diff --git a/examples/zoo/get_dataset.py b/examples/zoo/get_dataset.py
index e7dd3db..600b338 100644
--- a/examples/zoo/get_dataset.py
+++ b/examples/zoo/get_dataset.py
@@ -1,6 +1,7 @@
 import numpy as np
 import openml

+
 # Function to load and preprocess the dataset
 def load_and_preprocess_dataset(dataset_id):
     dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
diff --git a/examples/zoo/kb.py b/examples/zoo/kb.py
index 7442f6e..4757184 100644
--- a/examples/zoo/kb.py
+++ b/examples/zoo/kb.py
@@ -1,7 +1,9 @@
-from z3 import Solver, Int, If, Not, Implies, Sum, sat
 import openml
+from z3 import If, Implies, Int, Not, Solver, Sum, sat  # noqa: F401
+
 from abl.reasoning import KBBase
+

 class ZooKB(KBBase):
     def __init__(self):
         super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
diff --git a/examples/zoo/main.py b/examples/zoo/main.py
index f5c86eb..26bdc66 100644
--- a/examples/zoo/main.py
+++ b/examples/zoo/main.py
@@ -1,16 +1,18 @@
-import os.path as osp
 import argparse
+import os.path as osp

 import numpy as np
 from sklearn.ensemble import RandomForestClassifier

-from get_dataset import load_and_preprocess_dataset, split_dataset
+from abl.bridge import SimpleBridge
+from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
 from abl.learning import ABLModel
-from kb import ZooKB
 from abl.reasoning import Reasoner
-from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
-from abl.utils import ABLLogger, print_log, confidence_dist
-from abl.bridge import SimpleBridge
+from abl.utils import ABLLogger, confidence_dist, print_log
+
+from get_dataset import load_and_preprocess_dataset, split_dataset
+from kb import ZooKB
+

 def transform_tab_data(X, y):
     return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
diff --git a/tests/conftest.py b/tests/conftest.py
index dc299a8..3d9a2cd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,12 +1,12 @@
-import pytest
 import numpy as np
+import pytest
 import torch
 import torch.nn as nn
 import torch.optim as optim

+from abl.data.structures import ListData
 from abl.learning import BasicNN
 from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
-from abl.data.structures import ListData


 class LeNet5(nn.Module):
diff --git a/tests/test_basic_nn.py b/tests/test_basic_nn.py
index b0bdada..deaa760 100644
--- a/tests/test_basic_nn.py
+++ b/tests/test_basic_nn.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader, TensorDataset

+
 class TestBasicNN(object):
     @pytest.fixture
     def sample_data(self):
diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py
index 744b10d..e97ce03 100644
--- a/tests/test_reasoning.py
+++ b/tests/test_reasoning.py
@@ -1,5 +1,5 @@
-import pytest
 import numpy as np
+import pytest

 from abl.reasoning import PrologKB, Reasoner

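A minimal usage sketch (not part of the diff, assuming the abl package from this repository is importable) of the explicit exports introduced in abl/data/__init__.py and abl/learning/__init__.py above; the names are exactly those listed in the new __all__ definitions:

    # Each public name is now importable directly from its subpackage,
    # so the star imports removed by this change are no longer needed.
    from abl.data import BaseMetric, ListData, ReasoningMetric, SymbolAccuracy
    from abl.learning import ABLModel, BasicNN, ClassificationDataset, PredictionDataset, RegressionDataset

    print(ListData, SymbolAccuracy, BasicNN)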