From 53bb17bf3703f40846635ae38d087429ad84ee01 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sat, 11 Nov 2023 00:34:26 +0800 Subject: [PATCH] [ENH] add abstract data interface to bridge --- abl/bridge/simple_bridge.py | 61 ++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 4ca5628..cca83e4 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -1,14 +1,14 @@ -from typing import List, Union, Any, Tuple, Dict, Optional +import os.path as osp +from typing import Any, Dict, List, Optional, Tuple, Union from numpy import ndarray -from .base_bridge import BaseBridge, DataSet - +from ..evaluation import BaseMetric from ..learning import ABLModel from ..reasoning import ReasonerBase -from ..evaluation import BaseMetric from ..structures import ListData -from ..utils.logger import print_log +from ..utils import print_log +from .base_bridge import BaseBridge, DataSet class SimpleBridge(BaseBridge): @@ -21,11 +21,13 @@ class SimpleBridge(BaseBridge): super().__init__(model, abducer) self.metric_list = metric_list - def predict(self, data_samples: ListData) -> Tuple[List[ndarray], ndarray]: + # TODO: add abducer.mapping to the property of SimpleBridge + + def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: pred_res = self.model.predict(data_samples) data_samples.pred_idx = pred_res["label"] data_samples.pred_prob = pred_res["prob"] - return data_samples["pred_idx"], ["data_samples.pred_prob"] + return data_samples["pred_idx"], data_samples["pred_prob"] def abduce_pseudo_label( self, @@ -37,7 +39,7 @@ class SimpleBridge(BaseBridge): return data_samples["abduced_pseudo_label"] def idx_to_pseudo_label( - self, data_samples: ListData, mapping: Dict = None + self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.mapping @@ -48,7 +50,7 @@ class SimpleBridge(BaseBridge): return data_samples["pred_pseudo_label"] def pseudo_label_to_idx( - self, data_samples: ListData, mapping: Dict = None + self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.remapping @@ -59,9 +61,7 @@ class SimpleBridge(BaseBridge): data_samples.abduced_idx = abduced_idx return data_samples["abduced_idx"] - def data_preprocess( - self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any] - ) -> ListData: + def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: data_samples = ListData() data_samples.X = X @@ -72,17 +72,22 @@ class SimpleBridge(BaseBridge): def train( self, - train_data: DataSet, - epochs: int = 50, - batch_size: Union[int, float] = -1, + train_data: Union[ListData, DataSet], + loops: int = 50, + segment_size: Union[int, float] = -1, eval_interval: int = 1, + save_interval: Optional[int] = None, + save_dir: Optional[str] = None, ): - data_samples = self.data_preprocess(*train_data) + if isinstance(train_data, ListData): + data_samples = train_data + else: + data_samples = self.data_preprocess(*train_data) - for epoch in range(epochs): - for seg_idx in range((len(data_samples) - 1) // batch_size + 1): + for loop in range(loops): + for seg_idx in range((len(data_samples) - 1) // segment_size + 1): sub_data_samples = data_samples[ - seg_idx * batch_size : (seg_idx + 1) * batch_size + seg_idx * segment_size : (seg_idx + 1) * segment_size ] self.predict(sub_data_samples) self.idx_to_pseudo_label(sub_data_samples) @@ -91,25 +96,25 @@ class SimpleBridge(BaseBridge): loss = self.model.train(sub_data_samples) print_log( - f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{(len(data_samples) - 1) // batch_size + 1}] model loss is {loss:.5f}", + f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", logger="current", ) - if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: - print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") + if (loop + 1) % eval_interval == 0 or loop == loops - 1: + print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") self.valid(train_data) + if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): + print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") + self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) + def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: for seg_idx in range((len(data_samples) - 1) // batch_size + 1): - sub_data_samples = data_samples[ - seg_idx * batch_size : (seg_idx + 1) * batch_size - ] + sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] self.predict(sub_data_samples) self.idx_to_pseudo_label(sub_data_samples) - sub_data_samples.set_metainfo( - dict(logic_forward=self.abducer.kb.logic_forward) - ) + sub_data_samples.set_metainfo(dict(logic_forward=self.abducer.kb.logic_forward)) for metric in self.metric_list: metric.process(sub_data_samples)