|
- import numpy as np
- from sklearn.metrics import mean_squared_error
- from sklearn.model_selection import train_test_split
-
- from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware
- from config import align_model_params
- from train import train_model
-
-
def loss_func_rmse(y_true, y_pred):
    """Return the root-mean-squared error between ``y_true`` and ``y_pred``.

    The mean squared error (including scikit-learn's input validation) is
    computed by ``sklearn.metrics.mean_squared_error``; the square root is
    then taken with NumPy.
    """
    mse = mean_squared_error(y_true, y_pred)
    return np.sqrt(mse)
-
-
def user_model_score(x_train, y_train, test_info):
    """Train a model on the user's own data only; no learnwares are involved.

    The supplied training data is split 80/20 (fixed ``random_state=42`` so the
    split is reproducible) into a fitting part and a validation part, and both
    are handed to ``train_model`` together with ``test_info``.

    Despite the ``*_score`` name, this returns whatever ``train_model``
    returns (presumably the fitted model), not a numeric score.
    """
    fit_x, val_x, fit_y, val_y = train_test_split(
        x_train, y_train, test_size=0.2, random_state=42
    )
    return train_model(fit_x, fit_y, val_x, val_y, test_info)
-
-
class HomoScoringMethods:
    """Learnware reuse strategies registered under the ``homo_*`` keys.

    Every public method takes ``(x_train, y_train, test_info)`` and, despite the
    ``*_score`` suffix, returns a reuser (or a bare learnware) rather than a
    number -- presumably the caller predicts with it and scores the result.
    ``test_info`` is a dict; the keys read here are ``"single_learnware"`` and
    ``"learnwares"``.
    """

    @staticmethod
    def _fitted_feature_augment(learnware_arg, x_train, y_train):
        """Wrap ``learnware_arg`` in a regression ``FeatureAugmentReuser`` and fit it on the user's data."""
        augmenter = FeatureAugmentReuser(learnware_arg, mode="regression")
        augmenter.fit(x_train=x_train, y_train=y_train)
        return augmenter

    @staticmethod
    def single_aug_score(x_train, y_train, test_info):
        """Feature-augment reuse of ``test_info["single_learnware"]``.

        NOTE(review): the same reuser class also receives ``test_info["learnwares"]``
        (a collection), while the hetero path treats ``single_learnware`` as one
        learnware; presumably the homo caller stores it in a form
        ``FeatureAugmentReuser`` accepts (e.g. a one-element list) -- confirm.
        """
        return HomoScoringMethods._fitted_feature_augment(test_info["single_learnware"], x_train, y_train)

    @staticmethod
    def multiple_aug_score(x_train, y_train, test_info):
        """Feature-augment reuse of every learnware in ``test_info["learnwares"]``."""
        return HomoScoringMethods._fitted_feature_augment(test_info["learnwares"], x_train, y_train)

    @staticmethod
    def multiple_avg_score(x_train, y_train, test_info):
        """Build a mean-averaging reuser over ``test_info["learnwares"]``.

        No fitting happens here; ``x_train`` and ``y_train`` are accepted only
        to keep the common method signature.
        """
        return AveragingReuser(test_info["learnwares"], mode="mean")

    @staticmethod
    def multiple_ensemble_pruning_score(x_train, y_train, test_info):
        """Prune the ensemble ``test_info["learnwares"]`` for regression.

        When exactly one learnware is available there is nothing to prune, and
        that learnware is returned unchanged instead of a reuser.
        """
        candidates = test_info["learnwares"]
        if len(candidates) != 1:
            pruner = EnsemblePruningReuser(candidates, mode="regression")
            # The user's training data doubles as the pruning validation set.
            pruner.fit(val_X=x_train, val_y=y_train)
            return pruner
        return candidates[0]
-
-
class HeteroMethods:
    """Learnware reuse strategies registered under the ``hetero_*`` keys.

    Before reuse, each learnware is wrapped in a ``HeteroMapAlignLearnware``
    (regression mode, extra options from ``align_model_params``) and aligned to
    the user via ``test_info["user_rkme"]`` plus the user's training data.
    Every public strategy takes ``(x_train, y_train, test_info)`` and returns a
    reuser or aligned learnware, not a numeric score.
    """

    @staticmethod
    def _aligned(base, user_rkme, x_train, y_train):
        """Return ``base`` wrapped in a regression ``HeteroMapAlignLearnware`` aligned to the user."""
        wrapper = HeteroMapAlignLearnware(base, mode="regression", **align_model_params)
        wrapper.align(user_rkme=user_rkme, x_train=x_train, y_train=y_train)
        return wrapper

    @staticmethod
    def create_hetero_learnware_list(learnware_list, user_rkme, x_train, y_train):
        """Align every learnware in ``learnware_list``, in order, and return the aligned wrappers as a list."""
        return [HeteroMethods._aligned(base, user_rkme, x_train, y_train) for base in learnware_list]

    @staticmethod
    def single_aug_score(x_train, y_train, test_info):
        """Return ``test_info["single_learnware"]`` aligned to the user (no extra reuser)."""
        user_rkme = test_info["user_rkme"]
        target = test_info["single_learnware"]
        return HeteroMethods._aligned(target, user_rkme, x_train, y_train)

    @staticmethod
    def multiple_aug_score(x_train, y_train, test_info):
        """Align all ``test_info["learnwares"]`` and combine them with a fitted ``FeatureAugmentReuser``."""
        user_rkme = test_info["user_rkme"]
        pool = test_info["learnwares"]
        aligned = HeteroMethods.create_hetero_learnware_list(pool, user_rkme, x_train, y_train)
        augmenter = FeatureAugmentReuser(aligned, mode="regression")
        augmenter.fit(x_train=x_train, y_train=y_train)
        return augmenter

    @staticmethod
    def multiple_ensemble_pruning_score(x_train, y_train, test_info):
        """Align all ``test_info["learnwares"]`` and prune the ensemble for regression.

        With exactly one aligned learnware there is nothing to prune, so that
        aligned learnware is returned directly.
        """
        user_rkme = test_info["user_rkme"]
        pool = test_info["learnwares"]
        aligned = HeteroMethods.create_hetero_learnware_list(pool, user_rkme, x_train, y_train)
        if len(aligned) != 1:
            pruner = EnsemblePruningReuser(aligned, mode="regression")
            # The user's training data doubles as the pruning validation set.
            pruner.fit(val_X=x_train, val_y=y_train)
            return pruner
        return aligned[0]

    @staticmethod
    def multiple_avg_score(x_train, y_train, test_info):
        """Align all ``test_info["learnwares"]`` and average their predictions (mean mode)."""
        user_rkme = test_info["user_rkme"]
        pool = test_info["learnwares"]
        aligned = HeteroMethods.create_hetero_learnware_list(pool, user_rkme, x_train, y_train)
        return AveragingReuser(aligned, mode="mean")
-
-
# Registry of evaluation strategies, keyed by name.  Every entry is called as
# ``method(x_train, y_train, test_info)`` and returns a trained model or a
# learnware reuser, presumably for the caller to predict with and score.
# Insertion order is preserved in case callers iterate over the registry.
test_methods = dict(
    user_model=user_model_score,
    hetero_single_aug=HeteroMethods.single_aug_score,
    hetero_multiple_aug=HeteroMethods.multiple_aug_score,
    hetero_multiple_avg=HeteroMethods.multiple_avg_score,
    hetero_ensemble_pruning=HeteroMethods.multiple_ensemble_pruning_score,
    homo_single_aug=HomoScoringMethods.single_aug_score,
    homo_multiple_aug=HomoScoringMethods.multiple_aug_score,
    homo_multiple_avg=HomoScoringMethods.multiple_avg_score,
    homo_ensemble_pruning=HomoScoringMethods.multiple_ensemble_pruning_score,
)
|