Browse Source

[ENH] add abstract data interface to bridge

ab_data
Gao Enhao 1 year ago
parent
commit
53bb17bf37
1 changed file with 33 additions and 28 deletions
  1. +33
    -28
      abl/bridge/simple_bridge.py

+ 33
- 28
abl/bridge/simple_bridge.py View File

@@ -1,14 +1,14 @@
from typing import List, Union, Any, Tuple, Dict, Optional
import os.path as osp
from typing import Any, Dict, List, Optional, Tuple, Union

from numpy import ndarray

from .base_bridge import BaseBridge, DataSet

from ..evaluation import BaseMetric
from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..evaluation import BaseMetric
from ..structures import ListData
from ..utils.logger import print_log
from ..utils import print_log
from .base_bridge import BaseBridge, DataSet


class SimpleBridge(BaseBridge):
@@ -21,11 +21,13 @@ class SimpleBridge(BaseBridge):
super().__init__(model, abducer)
self.metric_list = metric_list

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], ndarray]:
# TODO: expose abducer.mapping as a property of SimpleBridge

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
pred_res = self.model.predict(data_samples)
data_samples.pred_idx = pred_res["label"]
data_samples.pred_prob = pred_res["prob"]
return data_samples["pred_idx"], ["data_samples.pred_prob"]
return data_samples["pred_idx"], data_samples["pred_prob"]

def abduce_pseudo_label(
self,
@@ -37,7 +39,7 @@ class SimpleBridge(BaseBridge):
return data_samples["abduced_pseudo_label"]

def idx_to_pseudo_label(
self, data_samples: ListData, mapping: Dict = None
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.mapping
@@ -48,7 +50,7 @@ class SimpleBridge(BaseBridge):
return data_samples["pred_pseudo_label"]

def pseudo_label_to_idx(
self, data_samples: ListData, mapping: Dict = None
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.remapping
@@ -59,9 +61,7 @@ class SimpleBridge(BaseBridge):
data_samples.abduced_idx = abduced_idx
return data_samples["abduced_idx"]

def data_preprocess(
self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]
) -> ListData:
def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData:
data_samples = ListData()

data_samples.X = X
@@ -72,17 +72,22 @@ class SimpleBridge(BaseBridge):

def train(
self,
train_data: DataSet,
epochs: int = 50,
batch_size: Union[int, float] = -1,
train_data: Union[ListData, DataSet],
loops: int = 50,
segment_size: Union[int, float] = -1,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
):
data_samples = self.data_preprocess(*train_data)
if isinstance(train_data, ListData):
data_samples = train_data
else:
data_samples = self.data_preprocess(*train_data)

for epoch in range(epochs):
for seg_idx in range((len(data_samples) - 1) // batch_size + 1):
for loop in range(loops):
for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
sub_data_samples = data_samples[
seg_idx * batch_size : (seg_idx + 1) * batch_size
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
self.predict(sub_data_samples)
self.idx_to_pseudo_label(sub_data_samples)
@@ -91,25 +96,25 @@ class SimpleBridge(BaseBridge):
loss = self.model.train(sub_data_samples)

print_log(
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{(len(data_samples) - 1) // batch_size + 1}] model loss is {loss:.5f}",
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
logger="current",
)

if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1:
print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current")
if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
self.valid(train_data)

if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth"))

def _valid(self, data_samples: ListData, batch_size: int = 128) -> None:
for seg_idx in range((len(data_samples) - 1) // batch_size + 1):
sub_data_samples = data_samples[
seg_idx * batch_size : (seg_idx + 1) * batch_size
]
sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size]
self.predict(sub_data_samples)
self.idx_to_pseudo_label(sub_data_samples)

sub_data_samples.set_metainfo(
dict(logic_forward=self.abducer.kb.logic_forward)
)
sub_data_samples.set_metainfo(dict(logic_forward=self.abducer.kb.logic_forward))
for metric in self.metric_list:
metric.process(sub_data_samples)



Loading…
Cancel
Save