|
|
@@ -1,14 +1,14 @@ |
|
|
|
from typing import List, Union, Any, Tuple, Dict, Optional |
|
|
|
import os.path as osp |
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
|
|
from numpy import ndarray |
|
|
|
|
|
|
|
from .base_bridge import BaseBridge, DataSet |
|
|
|
|
|
|
|
from ..evaluation import BaseMetric |
|
|
|
from ..learning import ABLModel |
|
|
|
from ..reasoning import ReasonerBase |
|
|
|
from ..evaluation import BaseMetric |
|
|
|
from ..structures import ListData |
|
|
|
from ..utils.logger import print_log |
|
|
|
from ..utils import print_log |
|
|
|
from .base_bridge import BaseBridge, DataSet |
|
|
|
|
|
|
|
|
|
|
|
class SimpleBridge(BaseBridge): |
|
|
@@ -21,11 +21,13 @@ class SimpleBridge(BaseBridge): |
|
|
|
super().__init__(model, abducer) |
|
|
|
self.metric_list = metric_list |
|
|
|
|
|
|
|
def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
    """Run the learning model on ``data_samples`` and record its predictions.

    The raw model output is stored on the samples in place
    (``pred_idx`` holds ``pred_res["label"]``, ``pred_prob`` holds
    ``pred_res["prob"]``) and then returned.

    Parameters
    ----------
    data_samples : ListData
        Batch of samples to predict on.

    Returns
    -------
    Tuple[List[ndarray], List[ndarray]]
        The predicted label indices and the corresponding probabilities.
    """
    # TODO: add abducer.mapping to the property of SimpleBridge
    pred_res = self.model.predict(data_samples)
    data_samples.pred_idx = pred_res["label"]
    data_samples.pred_prob = pred_res["prob"]
    # Fixed: the mangled variant returned the literal list
    # ["data_samples.pred_prob"] instead of the stored probabilities.
    return data_samples["pred_idx"], data_samples["pred_prob"]
|
|
|
def abduce_pseudo_label( |
|
|
|
self, |
|
|
@@ -37,7 +39,7 @@ class SimpleBridge(BaseBridge): |
|
|
|
return data_samples["abduced_pseudo_label"] |
|
|
|
|
|
|
|
def idx_to_pseudo_label( |
|
|
|
self, data_samples: ListData, mapping: Dict = None |
|
|
|
self, data_samples: ListData, mapping: Optional[Dict] = None |
|
|
|
) -> List[List[Any]]: |
|
|
|
if mapping is None: |
|
|
|
mapping = self.abducer.mapping |
|
|
@@ -48,7 +50,7 @@ class SimpleBridge(BaseBridge): |
|
|
|
return data_samples["pred_pseudo_label"] |
|
|
|
|
|
|
|
def pseudo_label_to_idx( |
|
|
|
self, data_samples: ListData, mapping: Dict = None |
|
|
|
self, data_samples: ListData, mapping: Optional[Dict] = None |
|
|
|
) -> List[List[Any]]: |
|
|
|
if mapping is None: |
|
|
|
mapping = self.abducer.remapping |
|
|
@@ -59,9 +61,7 @@ class SimpleBridge(BaseBridge): |
|
|
|
data_samples.abduced_idx = abduced_idx |
|
|
|
return data_samples["abduced_idx"] |
|
|
|
|
|
|
|
def data_preprocess( |
|
|
|
self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any] |
|
|
|
) -> ListData: |
|
|
|
def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: |
|
|
|
data_samples = ListData() |
|
|
|
|
|
|
|
data_samples.X = X |
|
|
@@ -72,17 +72,22 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
def train(
    self,
    train_data: Union[ListData, DataSet],
    loops: int = 50,
    segment_size: Union[int, float] = -1,
    eval_interval: int = 1,
    save_interval: Optional[int] = None,
    save_dir: Optional[str] = None,
):
    """Run the abductive-learning training loop.

    Parameters
    ----------
    train_data : Union[ListData, DataSet]
        Either an already-built ``ListData`` (used as-is) or raw data
        unpacked into ``data_preprocess`` (presumably
        ``(X, gt_pseudo_label, Y)`` — confirm against callers).
    loops : int
        Number of passes over the whole training data.
    segment_size : Union[int, float]
        Number of samples per training segment per step.
    eval_interval : int
        Evaluate every ``eval_interval`` loops (and always on the last loop).
    save_interval : Optional[int]
        If given, save a model checkpoint every ``save_interval`` loops
        (and on the last loop).
    save_dir : Optional[str]
        Directory the checkpoints are written to; must be set when
        ``save_interval`` is used.
    """
    # Accept a ready-made ListData directly; otherwise preprocess raw data.
    if isinstance(train_data, ListData):
        data_samples = train_data
    else:
        data_samples = self.data_preprocess(*train_data)

    for loop in range(loops):
        for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
            sub_data_samples = data_samples[
                seg_idx * segment_size : (seg_idx + 1) * segment_size
            ]
            self.predict(sub_data_samples)
            self.idx_to_pseudo_label(sub_data_samples)
            # NOTE(review): the next two calls were elided by the diff
            # context (old lines 89-90); reconstructed from the sibling
            # methods defined in this class — confirm against upstream.
            self.abduce_pseudo_label(sub_data_samples)
            self.pseudo_label_to_idx(sub_data_samples)
            loss = self.model.train(sub_data_samples)

            print_log(
                f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
                logger="current",
            )

        # Periodic evaluation (also forced on the final loop).
        if (loop + 1) % eval_interval == 0 or loop == loops - 1:
            print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
            self.valid(train_data)

        # Periodic checkpointing, only when a save interval was requested.
        if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
            print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
            self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth"))
|
|
|
def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: |
|
|
|
for seg_idx in range((len(data_samples) - 1) // batch_size + 1): |
|
|
|
sub_data_samples = data_samples[ |
|
|
|
seg_idx * batch_size : (seg_idx + 1) * batch_size |
|
|
|
] |
|
|
|
sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] |
|
|
|
self.predict(sub_data_samples) |
|
|
|
self.idx_to_pseudo_label(sub_data_samples) |
|
|
|
|
|
|
|
sub_data_samples.set_metainfo( |
|
|
|
dict(logic_forward=self.abducer.kb.logic_forward) |
|
|
|
) |
|
|
|
sub_data_samples.set_metainfo(dict(logic_forward=self.abducer.kb.logic_forward)) |
|
|
|
for metric in self.metric_list: |
|
|
|
metric.process(sub_data_samples) |
|
|
|
|
|
|
|