Browse Source

[MNT] resolve comments in basic_nn and abl_model

ab_data
Gao Enhao 1 year ago
parent
commit
4f388e1e2b
6 changed files with 110 additions and 18 deletions
  1. +2
    -4
      abl/bridge/simple_bridge.py
  2. +2
    -1
      abl/dataset/__init__.py
  3. +56
    -0
      abl/dataset/prediction_dataset.py
  4. +5
    -2
      abl/learning/abl_model.py
  5. +44
    -9
      abl/learning/basic_nn.py
  6. +1
    -2
      abl/reasoning/ground_kb.py

+ 2
- 4
abl/bridge/simple_bridge.py View File

@@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge):
# TODO: add abducer.mapping to the property of SimpleBridge

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
pred_res = self.model.predict(data_samples)
data_samples.pred_idx = pred_res["label"]
data_samples.pred_prob = pred_res["prob"]
return data_samples["pred_idx"], data_samples["pred_prob"]
self.model.predict(data_samples)
return data_samples["pred_idx"], data_samples.get("pred_prob", None)

def abduce_pseudo_label(
self,


+ 2
- 1
abl/dataset/__init__.py View File

@@ -1,3 +1,4 @@
from .bridge_dataset import BridgeDataset
from .classification_dataset import ClassificationDataset
from .regression_dataset import RegressionDataset
from .prediction_dataset import PredictionDataset
from .regression_dataset import RegressionDataset

+ 56
- 0
abl/dataset/prediction_dataset.py View File

@@ -0,0 +1,56 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset


class PredictionDataset(Dataset):
def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
"""
Initialize the dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
if not isinstance(X, list):
raise ValueError("X should be of type list.")

self.X = X
self.transform = transform

def __len__(self) -> int:
"""
Return the length of the dataset.

Returns
-------
int
The length of the dataset.
"""
return len(self.X)

def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
"""
Get the item at the given index.

Parameters
----------
index : int
The index of the item to get.

Returns
-------
Tuple[Any, torch.Tensor]
A tuple containing the object and its label.
"""
if index >= len(self):
raise ValueError("index range error")

x = self.X[index]
if self.transform is not None:
x = self.transform(x)
return x

+ 5
- 2
abl/learning/abl_model.py View File

@@ -71,11 +71,14 @@ class ABLModel:
label = prob.argmax(axis=1)
prob = reform_idx(prob, data_samples["X"])
else:
prob = [None] * len(data_samples)
prob = None
label = model.predict(X=data_X)

label = reform_idx(label, data_samples["X"])

data_samples.pred_idx = label
if prob is not None:
data_samples.pred_prob = prob

return {"label": label, "prob": prob}

def train(self, data_samples: ListData) -> float:


+ 44
- 9
abl/learning/basic_nn.py View File

@@ -11,13 +11,14 @@
# ================================================================#

import os
import logging
from typing import Any, Callable, List, Optional, T, Tuple

import numpy
import torch
from torch.utils.data import DataLoader

from ..dataset import ClassificationDataset
from ..dataset import ClassificationDataset, PredictionDataset
from ..utils.logger import print_log


@@ -197,7 +198,12 @@ class BasicNN:

return torch.cat(results, axis=0)

def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
def predict(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
test_transform: Callable[..., Any] = None,
) -> numpy.ndarray:
"""
Predict the class of the input data.

@@ -215,12 +221,29 @@ class BasicNN:
"""

if data_loader is None:
if self.transform is not None:
X = [self.transform(x) for x in X]
data_loader = DataLoader(X, batch_size=self.batch_size)
if test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
dataset = PredictionDataset(X, self.transform)
else:
dataset = PredictionDataset(X, test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
def predict_proba(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
test_transform: Callable[..., Any] = None,
) -> numpy.ndarray:
"""
Predict the probability of each class for the input data.

@@ -238,9 +261,21 @@ class BasicNN:
"""

if data_loader is None:
if self.transform is not None:
X = [self.transform(x) for x in X]
data_loader = DataLoader(X, batch_size=self.batch_size)
if test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
dataset = PredictionDataset(X, self.transform)
else:
dataset = PredictionDataset(X, test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()

def _score(self, data_loader) -> Tuple[float, float]:


+ 1
- 2
abl/reasoning/ground_kb.py View File

@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Hashable, List

from abl.structures import ListData

from ..structures import ListData
from .base_kb import BaseKB




Loading…
Cancel
Save