|
|
@@ -9,15 +9,17 @@ from typing import Any, List, Optional, Tuple, Union |
|
|
|
|
|
|
|
from numpy import ndarray |
|
|
|
|
|
|
|
import wandb |
|
|
|
|
|
|
|
from ..data.evaluation import BaseMetric |
|
|
|
from ..data.structures import ListData |
|
|
|
from ..learning import ABLModel |
|
|
|
from ..reasoning import Reasoner |
|
|
|
from ..utils import print_log |
|
|
|
from .base_bridge import BaseBridge |
|
|
|
from .base_bridge import BaseBridge, M, R |
|
|
|
|
|
|
|
|
|
|
|
class SimpleBridge(BaseBridge): |
|
|
|
class SimpleBridge(BaseBridge[M, R]): |
|
|
|
""" |
|
|
|
A basic implementation for bridging machine learning and reasoning parts. |
|
|
|
|
|
|
@@ -32,10 +34,10 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
model : ABLModel |
|
|
|
model : M |
|
|
|
The machine learning model wrapped in ``ABLModel``, which is mainly used for |
|
|
|
prediction and model training. |
|
|
|
reasoner : Reasoner |
|
|
|
reasoner : R |
|
|
|
The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision. |
|
|
|
metric_list : List[BaseMetric] |
|
|
|
A list of metrics used for evaluating the model's performance. |
|
|
@@ -43,12 +45,13 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
model: ABLModel, |
|
|
|
reasoner: Reasoner, |
|
|
|
model: M, |
|
|
|
reasoner: R, |
|
|
|
metric_list: List[BaseMetric], |
|
|
|
) -> None: |
|
|
|
super().__init__(model, reasoner) |
|
|
|
self.metric_list = metric_list |
|
|
|
self.use_wandb = self._check_wandb_available() |
|
|
|
if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ |
|
|
|
"confidence", |
|
|
|
"avg_confidence", |
|
|
@@ -59,6 +62,20 @@ class SimpleBridge(BaseBridge): |
|
|
|
+ "or 'avg_confidence', which are related to predicted probability." |
|
|
|
) |
|
|
|
|
|
|
|
def _check_wandb_available(self): |
|
|
|
""" |
|
|
|
Check if wandb is available and initialized. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
bool |
|
|
|
True if wandb is available and initialized, False otherwise. |
|
|
|
""" |
|
|
|
try: |
|
|
|
return wandb.run is not None |
|
|
|
except ImportError: |
|
|
|
return False |
|
|
|
|
|
|
|
def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: |
|
|
|
""" |
|
|
|
Predict class indices and probabilities (if ``predict_proba`` is implemented in |
|
|
@@ -129,10 +146,7 @@ class SimpleBridge(BaseBridge): |
|
|
|
A list of indices converted from pseudo-labels. |
|
|
|
""" |
|
|
|
abduced_idx = [ |
|
|
|
[ |
|
|
|
self.reasoner.label_to_idx[_abduced_pseudo_label] |
|
|
|
for _abduced_pseudo_label in sub_list |
|
|
|
] |
|
|
|
[self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] |
|
|
|
for sub_list in data_examples.abduced_pseudo_label |
|
|
|
] |
|
|
|
data_examples.abduced_idx = abduced_idx |
|
|
@@ -207,11 +221,12 @@ class SimpleBridge(BaseBridge): |
|
|
|
def train( |
|
|
|
self, |
|
|
|
train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], |
|
|
|
label_data: Optional[ |
|
|
|
Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]] |
|
|
|
] = None, |
|
|
|
label_data: Optional[Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]] = None, |
|
|
|
val_data: Optional[ |
|
|
|
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] |
|
|
|
Union[ |
|
|
|
ListData, |
|
|
|
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], |
|
|
|
] |
|
|
|
] = None, |
|
|
|
loops: int = 50, |
|
|
|
segment_size: Union[int, float] = 1.0, |
|
|
@@ -287,28 +302,26 @@ class SimpleBridge(BaseBridge): |
|
|
|
logger="current", |
|
|
|
) |
|
|
|
|
|
|
|
sub_data_examples = data_examples[ |
|
|
|
seg_idx * segment_size : (seg_idx + 1) * segment_size |
|
|
|
] |
|
|
|
sub_data_examples = data_examples[seg_idx * segment_size : (seg_idx + 1) * segment_size] |
|
|
|
self.predict(sub_data_examples) |
|
|
|
self.idx_to_pseudo_label(sub_data_examples) |
|
|
|
self.abduce_pseudo_label(sub_data_examples) |
|
|
|
self.filter_pseudo_label(sub_data_examples) |
|
|
|
self.concat_data_examples(sub_data_examples, label_data_examples) |
|
|
|
self.pseudo_label_to_idx(sub_data_examples) |
|
|
|
if len(sub_data_examples) == 0: |
|
|
|
continue |
|
|
|
self.model.train(sub_data_examples) |
|
|
|
|
|
|
|
if (loop + 1) % eval_interval == 0 or loop == loops - 1: |
|
|
|
print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") |
|
|
|
self._valid(val_data_examples) |
|
|
|
self._valid(val_data_examples, prefix="val") |
|
|
|
|
|
|
|
if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): |
|
|
|
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") |
|
|
|
self.model.save( |
|
|
|
save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") |
|
|
|
) |
|
|
|
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) |
|
|
|
|
|
|
|
def _valid(self, data_examples: ListData) -> None: |
|
|
|
def _valid(self, data_examples: ListData, prefix: str = "val") -> None: |
|
|
|
""" |
|
|
|
Internal method for validating the model with given data examples. |
|
|
|
|
|
|
@@ -320,21 +333,40 @@ class SimpleBridge(BaseBridge): |
|
|
|
self.predict(data_examples) |
|
|
|
self.idx_to_pseudo_label(data_examples) |
|
|
|
|
|
|
|
for metric in self.metric_list: |
|
|
|
metric.prefix = prefix |
|
|
|
|
|
|
|
for metric in self.metric_list: |
|
|
|
metric.process(data_examples) |
|
|
|
|
|
|
|
res = dict() |
|
|
|
for metric in self.metric_list: |
|
|
|
res.update(metric.evaluate()) |
|
|
|
|
|
|
|
msg = "Evaluation ended, " |
|
|
|
for k, v in res.items(): |
|
|
|
msg += k + f": {v:.3f} " |
|
|
|
try: |
|
|
|
v = float(v) |
|
|
|
msg += k + f": {v:.3f} " |
|
|
|
except: |
|
|
|
pass |
|
|
|
|
|
|
|
if self.use_wandb: |
|
|
|
try: |
|
|
|
wandb_metrics = {} |
|
|
|
for k, v in res.items(): |
|
|
|
wandb_metrics[f"{k}"] = v |
|
|
|
wandb.log(wandb_metrics) |
|
|
|
except Exception as e: |
|
|
|
print_log(f"Failed to log metrics to wandb: {e}", logger="current") |
|
|
|
|
|
|
|
print_log(msg, logger="current") |
|
|
|
|
|
|
|
def valid( |
|
|
|
self, |
|
|
|
val_data: Union[ |
|
|
|
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] |
|
|
|
ListData, |
|
|
|
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], |
|
|
|
], |
|
|
|
) -> None: |
|
|
|
""" |
|
|
@@ -349,12 +381,13 @@ class SimpleBridge(BaseBridge): |
|
|
|
``self.metric_list``. |
|
|
|
""" |
|
|
|
val_data_examples = self.data_preprocess("val", val_data) |
|
|
|
self._valid(val_data_examples) |
|
|
|
self._valid(val_data_examples, prefix="val") |
|
|
|
|
|
|
|
def test( |
|
|
|
self, |
|
|
|
test_data: Union[ |
|
|
|
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] |
|
|
|
ListData, |
|
|
|
Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], |
|
|
|
], |
|
|
|
) -> None: |
|
|
|
""" |
|
|
@@ -370,4 +403,4 @@ class SimpleBridge(BaseBridge): |
|
|
|
""" |
|
|
|
print_log("Test start:", logger="current") |
|
|
|
test_data_examples = self.data_preprocess("test", test_data) |
|
|
|
self._valid(test_data_examples) |
|
|
|
self._valid(test_data_examples, prefix="test") |