
[MNT] resolve comments in basic_nn and abl_model

ab_data
Gao Enhao · 1 year ago
commit 4f388e1e2b
6 changed files with 110 additions and 18 deletions
  1. abl/bridge/simple_bridge.py (+2, -4)
  2. abl/dataset/__init__.py (+2, -1)
  3. abl/dataset/prediction_dataset.py (+56, -0)
  4. abl/learning/abl_model.py (+5, -2)
  5. abl/learning/basic_nn.py (+44, -9)
  6. abl/reasoning/ground_kb.py (+1, -2)

abl/bridge/simple_bridge.py (+2, -4)

@@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge):
     # TODO: add abducer.mapping to the property of SimpleBridge

     def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
-        pred_res = self.model.predict(data_samples)
-        data_samples.pred_idx = pred_res["label"]
-        data_samples.pred_prob = pred_res["prob"]
-        return data_samples["pred_idx"], data_samples["pred_prob"]
+        self.model.predict(data_samples)
+        return data_samples["pred_idx"], data_samples.get("pred_prob", None)

     def abduce_pseudo_label(
         self,


abl/dataset/__init__.py (+2, -1)

@@ -1,3 +1,4 @@
 from .bridge_dataset import BridgeDataset
 from .classification_dataset import ClassificationDataset
-from .regression_dataset import RegressionDataset
+from .prediction_dataset import PredictionDataset
+from .regression_dataset import RegressionDataset

abl/dataset/prediction_dataset.py (+56, -0)

@@ -0,0 +1,56 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset


class PredictionDataset(Dataset):
    def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
        """
        Initialize the dataset used for prediction.

        Parameters
        ----------
        X : List[Any]
            The input data.
        transform : Callable[..., Any], optional
            A function/transform that takes in an object and returns a transformed version. Defaults to None.
        """
        if not isinstance(X, list):
            raise ValueError("X should be of type list.")

        self.X = X
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Any:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get.

        Returns
        -------
        Any
            The object at the given index, transformed if a transform was provided.
        """
        if index >= len(self):
            raise ValueError("index range error")

        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)
        return x
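
For reference, PredictionDataset only stores unlabeled inputs and applies the optional transform in __getitem__, so it can be wrapped directly in a DataLoader for inference. A minimal usage sketch, assuming the abl package is importable; the tensor data and batch size are illustrative and not part of this commit:

# Hypothetical usage of the new dataset; not part of the commit itself.
import torch
from torch.utils.data import DataLoader
from abl.dataset import PredictionDataset

X = [torch.randn(1, 28, 28) for _ in range(8)]   # dummy unlabeled inputs
dataset = PredictionDataset(X)                   # transform is optional
loader = DataLoader(dataset, batch_size=4)

for batch in loader:
    print(batch.shape)  # torch.Size([4, 1, 28, 28])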

abl/learning/abl_model.py (+5, -2)

@@ -71,11 +71,14 @@ class ABLModel:
             label = prob.argmax(axis=1)
             prob = reform_idx(prob, data_samples["X"])
         else:
-            prob = [None] * len(data_samples)
+            prob = None
             label = model.predict(X=data_X)

         label = reform_idx(label, data_samples["X"])

+        data_samples.pred_idx = label
+        if prob is not None:
+            data_samples.pred_prob = prob
+
         return {"label": label, "prob": prob}

     def train(self, data_samples: ListData) -> float:
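
Two things change here: prob is now None (rather than a list of Nones) when the base model has no predict_proba, and predict writes its outputs onto the ListData container (pred_idx always, pred_prob only when probabilities exist), which is what the simplified SimpleBridge.predict above relies on. Below is a standalone sketch of the probability branch using toy stand-in models; the toy classes and helper are illustrative, not the project's API:

# Illustrative sketch of the rewritten branch; not part of the commit.
import numpy as np


class WithProba:
    """Toy base model exposing predict_proba."""
    def predict_proba(self, X):
        return np.array([[0.2, 0.8], [0.9, 0.1]])


class WithoutProba:
    """Toy base model exposing only predict."""
    def predict(self, X):
        return np.array([1, 0])


def predict_like_abl_model(model, X):
    # Mirrors the branch above: prob stays None when unavailable.
    if hasattr(model, "predict_proba"):
        prob = model.predict_proba(X=X)
        label = prob.argmax(axis=1)
    else:
        prob = None
        label = model.predict(X=X)
    return {"label": label, "prob": prob}


print(predict_like_abl_model(WithProba(), X=[[0.0], [1.0]]))     # prob is an array
print(predict_like_abl_model(WithoutProba(), X=[[0.0], [1.0]]))  # prob is None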


abl/learning/basic_nn.py (+44, -9)

@@ -11,13 +11,14 @@
 # ================================================================#

 import os
+import logging
 from typing import Any, Callable, List, Optional, T, Tuple

 import numpy
 import torch
 from torch.utils.data import DataLoader

-from ..dataset import ClassificationDataset
+from ..dataset import ClassificationDataset, PredictionDataset
 from ..utils.logger import print_log




@@ -197,7 +198,12 @@ class BasicNN:


         return torch.cat(results, axis=0)

-    def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
+    def predict(
+        self,
+        data_loader: DataLoader = None,
+        X: List[Any] = None,
+        test_transform: Callable[..., Any] = None,
+    ) -> numpy.ndarray:
         """
         Predict the class of the input data.


@@ -215,12 +221,29 @@ class BasicNN:
""" """


if data_loader is None: if data_loader is None:
if self.transform is not None:
X = [self.transform(x) for x in X]
data_loader = DataLoader(X, batch_size=self.batch_size)
if test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
dataset = PredictionDataset(X, self.transform)
else:
dataset = PredictionDataset(X, test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).argmax(axis=1).cpu().numpy() return self._predict(data_loader).argmax(axis=1).cpu().numpy()


def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
def predict_proba(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
test_transform: Callable[..., Any] = None,
) -> numpy.ndarray:
""" """
Predict the probability of each class for the input data. Predict the probability of each class for the input data.


@@ -238,9 +261,21 @@ class BasicNN:
""" """


if data_loader is None: if data_loader is None:
if self.transform is not None:
X = [self.transform(x) for x in X]
data_loader = DataLoader(X, batch_size=self.batch_size)
if test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
dataset = PredictionDataset(X, self.transform)
else:
dataset = PredictionDataset(X, test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).softmax(axis=1).cpu().numpy() return self._predict(data_loader).softmax(axis=1).cpu().numpy()


def _score(self, data_loader) -> Tuple[float, float]: def _score(self, data_loader) -> Tuple[float, float]:
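
The net effect is that predict and predict_proba now accept a prediction-time transform and fall back to the training transform (with a logged warning) when none is given, building the DataLoader from PredictionDataset instead of wrapping raw lists. A standalone sketch of that fallback, assuming the abl package is importable; the data, transforms, and batch size are illustrative, and logging.warning stands in for the project's print_log:

# Illustrative sketch of the new fallback, outside of BasicNN.
import logging
import torch
from torch.utils.data import DataLoader
from abl.dataset import PredictionDataset

train_transform = lambda x: x * 2.0   # stands in for the transform used during training
test_transform = None                 # caller did not pass test_transform

X = [torch.ones(3) for _ in range(4)]
if test_transform is None:
    logging.warning("Transform used in the training phase will be used in prediction.")
    dataset = PredictionDataset(X, train_transform)
else:
    dataset = PredictionDataset(X, test_transform)

loader = DataLoader(dataset, batch_size=2)
print(next(iter(loader)))             # batches produced with the training transform

At the call site, a trained BasicNN would then be invoked as model.predict(X=X_test, test_transform=...) or model.predict_proba(X=X_test, test_transform=...), per the signatures above.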


abl/reasoning/ground_kb.py (+1, -2)

@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Hashable, List

-from abl.structures import ListData
-
+from ..structures import ListData
 from .base_kb import BaseKB





