Browse Source

[FEATURE] Enhance SimpleBridge with wandb integration and type parameterization

dev
Gao Enhao 3 months ago
parent
commit
0cd57efb91
1 changed files with 60 additions and 27 deletions
  1. +60
    -27
      ablkit/bridge/simple_bridge.py

+ 60
- 27
ablkit/bridge/simple_bridge.py View File

@@ -9,15 +9,17 @@ from typing import Any, List, Optional, Tuple, Union


from numpy import ndarray from numpy import ndarray


import wandb

from ..data.evaluation import BaseMetric from ..data.evaluation import BaseMetric
from ..data.structures import ListData from ..data.structures import ListData
from ..learning import ABLModel from ..learning import ABLModel
from ..reasoning import Reasoner from ..reasoning import Reasoner
from ..utils import print_log from ..utils import print_log
from .base_bridge import BaseBridge
from .base_bridge import BaseBridge, M, R




class SimpleBridge(BaseBridge):
class SimpleBridge(BaseBridge[M, R]):
""" """
A basic implementation for bridging machine learning and reasoning parts. A basic implementation for bridging machine learning and reasoning parts.


@@ -32,10 +34,10 @@ class SimpleBridge(BaseBridge):


Parameters Parameters
---------- ----------
model : ABLModel
model : M
The machine learning model wrapped in ``ABLModel``, which is mainly used for The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training. prediction and model training.
reasoner : Reasoner
reasoner : R
The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision. The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
metric_list : List[BaseMetric] metric_list : List[BaseMetric]
A list of metrics used for evaluating the model's performance. A list of metrics used for evaluating the model's performance.
@@ -43,12 +45,13 @@ class SimpleBridge(BaseBridge):


def __init__( def __init__(
self, self,
model: ABLModel,
reasoner: Reasoner,
model: M,
reasoner: R,
metric_list: List[BaseMetric], metric_list: List[BaseMetric],
) -> None: ) -> None:
super().__init__(model, reasoner) super().__init__(model, reasoner)
self.metric_list = metric_list self.metric_list = metric_list
self.use_wandb = self._check_wandb_available()
if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [
"confidence", "confidence",
"avg_confidence", "avg_confidence",
@@ -59,6 +62,20 @@ class SimpleBridge(BaseBridge):
+ "or 'avg_confidence', which are related to predicted probability." + "or 'avg_confidence', which are related to predicted probability."
) )


def _check_wandb_available(self):
"""
Check if wandb is available and initialized.

Returns
-------
bool
True if wandb is available and initialized, False otherwise.
"""
try:
return wandb.run is not None
except ImportError:
return False

def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
""" """
Predict class indices and probabilities (if ``predict_proba`` is implemented in Predict class indices and probabilities (if ``predict_proba`` is implemented in
@@ -129,10 +146,7 @@ class SimpleBridge(BaseBridge):
A list of indices converted from pseudo-labels. A list of indices converted from pseudo-labels.
""" """
abduced_idx = [ abduced_idx = [
[
self.reasoner.label_to_idx[_abduced_pseudo_label]
for _abduced_pseudo_label in sub_list
]
[self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_examples.abduced_pseudo_label for sub_list in data_examples.abduced_pseudo_label
] ]
data_examples.abduced_idx = abduced_idx data_examples.abduced_idx = abduced_idx
@@ -207,11 +221,12 @@ class SimpleBridge(BaseBridge):
def train( def train(
self, self,
train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
label_data: Optional[
Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]
] = None,
label_data: Optional[Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]] = None,
val_data: Optional[ val_data: Optional[
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
Union[
ListData,
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
]
] = None, ] = None,
loops: int = 50, loops: int = 50,
segment_size: Union[int, float] = 1.0, segment_size: Union[int, float] = 1.0,
@@ -287,28 +302,26 @@ class SimpleBridge(BaseBridge):
logger="current", logger="current",
) )


sub_data_examples = data_examples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
sub_data_examples = data_examples[seg_idx * segment_size : (seg_idx + 1) * segment_size]
self.predict(sub_data_examples) self.predict(sub_data_examples)
self.idx_to_pseudo_label(sub_data_examples) self.idx_to_pseudo_label(sub_data_examples)
self.abduce_pseudo_label(sub_data_examples) self.abduce_pseudo_label(sub_data_examples)
self.filter_pseudo_label(sub_data_examples) self.filter_pseudo_label(sub_data_examples)
self.concat_data_examples(sub_data_examples, label_data_examples) self.concat_data_examples(sub_data_examples, label_data_examples)
self.pseudo_label_to_idx(sub_data_examples) self.pseudo_label_to_idx(sub_data_examples)
if len(sub_data_examples) == 0:
continue
self.model.train(sub_data_examples) self.model.train(sub_data_examples)


if (loop + 1) % eval_interval == 0 or loop == loops - 1: if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
self._valid(val_data_examples)
self._valid(val_data_examples, prefix="val")


if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
self.model.save(
save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
)
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth"))


def _valid(self, data_examples: ListData) -> None:
def _valid(self, data_examples: ListData, prefix: str = "val") -> None:
""" """
Internal method for validating the model with given data examples. Internal method for validating the model with given data examples.


@@ -320,21 +333,40 @@ class SimpleBridge(BaseBridge):
self.predict(data_examples) self.predict(data_examples)
self.idx_to_pseudo_label(data_examples) self.idx_to_pseudo_label(data_examples)


for metric in self.metric_list:
metric.prefix = prefix

for metric in self.metric_list: for metric in self.metric_list:
metric.process(data_examples) metric.process(data_examples)


res = dict() res = dict()
for metric in self.metric_list: for metric in self.metric_list:
res.update(metric.evaluate()) res.update(metric.evaluate())

msg = "Evaluation ended, " msg = "Evaluation ended, "
for k, v in res.items(): for k, v in res.items():
msg += k + f": {v:.3f} "
try:
v = float(v)
msg += k + f": {v:.3f} "
except:
pass

if self.use_wandb:
try:
wandb_metrics = {}
for k, v in res.items():
wandb_metrics[f"{k}"] = v
wandb.log(wandb_metrics)
except Exception as e:
print_log(f"Failed to log metrics to wandb: {e}", logger="current")

print_log(msg, logger="current") print_log(msg, logger="current")


def valid( def valid(
self, self,
val_data: Union[ val_data: Union[
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
ListData,
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
], ],
) -> None: ) -> None:
""" """
@@ -349,12 +381,13 @@ class SimpleBridge(BaseBridge):
``self.metric_list``. ``self.metric_list``.
""" """
val_data_examples = self.data_preprocess("val", val_data) val_data_examples = self.data_preprocess("val", val_data)
self._valid(val_data_examples)
self._valid(val_data_examples, prefix="val")


def test( def test(
self, self,
test_data: Union[ test_data: Union[
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
ListData,
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
], ],
) -> None: ) -> None:
""" """
@@ -370,4 +403,4 @@ class SimpleBridge(BaseBridge):
""" """
print_log("Test start:", logger="current") print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data) test_data_examples = self.data_preprocess("test", test_data)
self._valid(test_data_examples)
self._valid(test_data_examples, prefix="test")

Loading…
Cancel
Save