import argparse import os.path as osp import numpy as np from sklearn.ensemble import RandomForestClassifier from abl.bridge import SimpleBridge from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.learning import ABLModel from abl.reasoning import Reasoner from abl.utils import ABLLogger, confidence_dist, print_log from get_dataset import load_and_preprocess_dataset, split_dataset from kb import ZooKB def transform_tab_data(X, y): return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) def consitency(data_example, candidates, candidate_idxs, reasoning_results): pred_prob = data_example.pred_prob model_scores = confidence_dist(pred_prob, candidate_idxs) rule_scores = np.array(reasoning_results) scores = model_scores + rule_scores return scores def main(): parser = argparse.ArgumentParser(description="Zoo example") parser.add_argument( "--loops", type=int, default=3, help="number of loop iterations (default : 3)" ) args = parser.parse_args() ### Working with Data X, y = load_and_preprocess_dataset(dataset_id=62) X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) label_data = transform_tab_data(X_label, y_label) test_data = transform_tab_data(X_test, y_test) train_data = transform_tab_data(X_unlabel, y_unlabel) ### Building the Learning Part base_model = RandomForestClassifier() # Build ABLModel model = ABLModel(base_model) ### Building the Reasoning Part # Build knowledge base kb = ZooKB() # Create reasoner reasoner = Reasoner(kb, dist_func=consitency) ### Building Evaluation Metrics metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] # Build logger print_log("Abductive Learning on the ZOO example.", logger="current") log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") ### Bridging learning and reasoning bridge = SimpleBridge(model, reasoner, metric_list) # Performing training and testing print_log("------- Use labeled data to pretrain the model -----------", logger="current") base_model.fit(X_label, y_label) print_log("------- Test the initial model -----------", logger="current") bridge.test(test_data) print_log("------- Use ABL to train the model -----------", logger="current") bridge.train(train_data=train_data, label_data=label_data, loops=args.loops, segment_size=len(X_unlabel), save_dir=weights_dir) print_log("------- Test the final model -----------", logger="current") bridge.test(test_data) if __name__ == "__main__": main()