Browse Source

[FEATURE] Enhance SimpleBridge with wandb integration and type parameterization

dev
Gao Enhao 3 months ago
parent
commit
0cd57efb91
1 changed files with 60 additions and 27 deletions
  1. +60
    -27
      ablkit/bridge/simple_bridge.py

+ 60
- 27
ablkit/bridge/simple_bridge.py View File

@@ -9,15 +9,17 @@ from typing import Any, List, Optional, Tuple, Union

from numpy import ndarray

import wandb

from ..data.evaluation import BaseMetric
from ..data.structures import ListData
from ..learning import ABLModel
from ..reasoning import Reasoner
from ..utils import print_log
from .base_bridge import BaseBridge
from .base_bridge import BaseBridge, M, R


class SimpleBridge(BaseBridge):
class SimpleBridge(BaseBridge[M, R]):
"""
A basic implementation for bridging machine learning and reasoning parts.

@@ -32,10 +34,10 @@ class SimpleBridge(BaseBridge):

Parameters
----------
model : ABLModel
model : M
The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training.
reasoner : Reasoner
reasoner : R
The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
metric_list : List[BaseMetric]
A list of metrics used for evaluating the model's performance.
@@ -43,12 +45,13 @@ class SimpleBridge(BaseBridge):

def __init__(
self,
model: ABLModel,
reasoner: Reasoner,
model: M,
reasoner: R,
metric_list: List[BaseMetric],
) -> None:
super().__init__(model, reasoner)
self.metric_list = metric_list
self.use_wandb = self._check_wandb_available()
if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [
"confidence",
"avg_confidence",
@@ -59,6 +62,20 @@ class SimpleBridge(BaseBridge):
+ "or 'avg_confidence', which are related to predicted probability."
)

def _check_wandb_available(self):
"""
Check if wandb is available and initialized.

Returns
-------
bool
True if wandb is available and initialized, False otherwise.
"""
try:
return wandb.run is not None
except ImportError:
return False

def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
"""
Predict class indices and probabilities (if ``predict_proba`` is implemented in
@@ -129,10 +146,7 @@ class SimpleBridge(BaseBridge):
A list of indices converted from pseudo-labels.
"""
abduced_idx = [
[
self.reasoner.label_to_idx[_abduced_pseudo_label]
for _abduced_pseudo_label in sub_list
]
[self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_examples.abduced_pseudo_label
]
data_examples.abduced_idx = abduced_idx
@@ -207,11 +221,12 @@ class SimpleBridge(BaseBridge):
def train(
self,
train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
label_data: Optional[
Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]
] = None,
label_data: Optional[Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]] = None,
val_data: Optional[
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
Union[
ListData,
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
]
] = None,
loops: int = 50,
segment_size: Union[int, float] = 1.0,
@@ -287,28 +302,26 @@ class SimpleBridge(BaseBridge):
logger="current",
)

sub_data_examples = data_examples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
sub_data_examples = data_examples[seg_idx * segment_size : (seg_idx + 1) * segment_size]
self.predict(sub_data_examples)
self.idx_to_pseudo_label(sub_data_examples)
self.abduce_pseudo_label(sub_data_examples)
self.filter_pseudo_label(sub_data_examples)
self.concat_data_examples(sub_data_examples, label_data_examples)
self.pseudo_label_to_idx(sub_data_examples)
if len(sub_data_examples) == 0:
continue
self.model.train(sub_data_examples)

if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
self._valid(val_data_examples)
self._valid(val_data_examples, prefix="val")

if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
self.model.save(
save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
)
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth"))

def _valid(self, data_examples: ListData) -> None:
def _valid(self, data_examples: ListData, prefix: str = "val") -> None:
"""
Internal method for validating the model with given data examples.

@@ -320,21 +333,40 @@ class SimpleBridge(BaseBridge):
self.predict(data_examples)
self.idx_to_pseudo_label(data_examples)

for metric in self.metric_list:
metric.prefix = prefix

for metric in self.metric_list:
metric.process(data_examples)

res = dict()
for metric in self.metric_list:
res.update(metric.evaluate())

msg = "Evaluation ended, "
for k, v in res.items():
msg += k + f": {v:.3f} "
try:
v = float(v)
msg += k + f": {v:.3f} "
except:
pass

if self.use_wandb:
try:
wandb_metrics = {}
for k, v in res.items():
wandb_metrics[f"{k}"] = v
wandb.log(wandb_metrics)
except Exception as e:
print_log(f"Failed to log metrics to wandb: {e}", logger="current")

print_log(msg, logger="current")

def valid(
self,
val_data: Union[
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
ListData,
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
],
) -> None:
"""
@@ -349,12 +381,13 @@ class SimpleBridge(BaseBridge):
``self.metric_list``.
"""
val_data_examples = self.data_preprocess("val", val_data)
self._valid(val_data_examples)
self._valid(val_data_examples, prefix="val")

def test(
self,
test_data: Union[
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
ListData,
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
],
) -> None:
"""
@@ -370,4 +403,4 @@ class SimpleBridge(BaseBridge):
"""
print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data)
self._valid(test_data_examples)
self._valid(test_data_examples, prefix="test")

Loading…
Cancel
Save