You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

methods.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import numpy as np
  2. from sklearn.metrics import mean_squared_error
  3. from sklearn.model_selection import train_test_split
  4. from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware
  5. from config import align_model_params
  6. from train import train_model
  def loss_func_rmse(y_true, y_pred):
      """Return the root-mean-squared error between ``y_true`` and ``y_pred``.

      Computed as the square root of scikit-learn's ``mean_squared_error``.
      """
      mse = mean_squared_error(y_true, y_pred)
      return np.sqrt(mse)
  def user_model_score(x_train, y_train, test_info):
      """Train a model on the user's own data, without using any learnware.

      The given data is split 80/20 (fixed ``random_state=42``) into a fitting
      set and a validation set; training is delegated to ``train.train_model``.

      Args:
          x_train: user features to split.
          y_train: user targets to split.
          test_info: forwarded unchanged to ``train_model``.

      Returns:
          Whatever ``train_model`` returns.
      """
      fit_x, val_x, fit_y, val_y = train_test_split(
          x_train, y_train, test_size=0.2, random_state=42
      )
      return train_model(fit_x, fit_y, val_x, val_y, test_info)
  class HomoScoringMethods:
      """Reuse strategies that pass learnwares to the reusers as-is.

      Unlike ``HeteroMethods``, no ``HeteroMapAlignLearnware`` alignment step is
      applied. Every method has the signature ``(x_train, y_train, test_info)``
      and reads its learnwares from ``test_info``.
      """

      @staticmethod
      def single_aug_score(x_train, y_train, test_info):
          """Fit a regression ``FeatureAugmentReuser`` over ``test_info["single_learnware"]``.

          Returns:
              The fitted ``FeatureAugmentReuser``.
          """
          # NOTE(review): this value is passed where FeatureAugmentReuser expects its
          # learnware list; presumably the caller stores a one-element list — confirm.
          augmenter = FeatureAugmentReuser(test_info["single_learnware"], mode="regression")
          augmenter.fit(x_train=x_train, y_train=y_train)
          return augmenter

      @staticmethod
      def multiple_aug_score(x_train, y_train, test_info):
          """Fit a regression ``FeatureAugmentReuser`` over all of ``test_info["learnwares"]``.

          Returns:
              The fitted ``FeatureAugmentReuser``.
          """
          augmenter = FeatureAugmentReuser(test_info["learnwares"], mode="regression")
          augmenter.fit(x_train=x_train, y_train=y_train)
          return augmenter

      @staticmethod
      def multiple_avg_score(x_train, y_train, test_info):
          """Build an ``AveragingReuser`` (``mode="mean"``) over ``test_info["learnwares"]``.

          ``x_train`` and ``y_train`` are accepted for signature uniformity but unused;
          no fitting is performed.
          """
          return AveragingReuser(test_info["learnwares"], mode="mean")

      @staticmethod
      def multiple_ensemble_pruning_score(x_train, y_train, test_info):
          """Fit a regression ``EnsemblePruningReuser`` over ``test_info["learnwares"]``.

          Returns:
              The sole learnware itself when exactly one is given, otherwise the
              fitted ``EnsemblePruningReuser``.
          """
          candidates = test_info["learnwares"]
          if len(candidates) != 1:
              pruner = EnsemblePruningReuser(candidates, mode="regression")
              # The pruner is fitted on the training data itself (no held-out split).
              pruner.fit(val_X=x_train, val_y=y_train)
              return pruner
          # With a single candidate there is nothing to prune; return it directly.
          return candidates[0]
  class HeteroMethods:
      """Reuse strategies that first align each learnware to the user's task.

      Every learnware is wrapped in ``HeteroMapAlignLearnware`` (regression mode,
      configured with ``config.align_model_params``) and aligned using the user's
      RKME specification (``test_info["user_rkme"]``) and training data before it
      is handed to a reuser. Presumably intended for learnwares whose feature
      space differs from the user's — confirm against the caller.

      Public methods share the signature ``(x_train, y_train, test_info)``, except
      ``create_hetero_learnware_list``.
      """

      @staticmethod
      def create_hetero_learnware_list(learnware_list, user_rkme, x_train, y_train):
          """Wrap and align every learnware in ``learnware_list``.

          Args:
              learnware_list: iterable of learnwares to wrap.
              user_rkme: the user's RKME specification passed to ``align``.
              x_train: user features passed to ``align``.
              y_train: user targets passed to ``align``.

          Returns:
              A list of aligned ``HeteroMapAlignLearnware`` objects, in input order
              (an empty list for an empty input).
          """
          return [
              HeteroMethods._align_learnware(learnware, user_rkme, x_train, y_train)
              for learnware in learnware_list
          ]

      @staticmethod
      def _align_learnware(learnware, user_rkme, x_train, y_train):
          """Wrap one learnware in ``HeteroMapAlignLearnware`` and align it to the user."""
          hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression", **align_model_params)
          hetero_learnware.align(user_rkme=user_rkme, x_train=x_train, y_train=y_train)
          return hetero_learnware

      @staticmethod
      def _aligned_learnwares(x_train, y_train, test_info):
          """Return ``test_info["learnwares"]`` wrapped and aligned to the user."""
          user_rkme, multiple_learnwares = test_info["user_rkme"], test_info["learnwares"]
          return HeteroMethods.create_hetero_learnware_list(multiple_learnwares, user_rkme, x_train, y_train)

      @staticmethod
      def single_aug_score(x_train, y_train, test_info):
          """Align ``test_info["single_learnware"]`` to the user and return it.

          Returns:
              The aligned ``HeteroMapAlignLearnware`` itself (no extra reuser wrapper).
          """
          user_rkme, single_learnware = test_info["user_rkme"], test_info["single_learnware"]
          return HeteroMethods._align_learnware(single_learnware, user_rkme, x_train, y_train)

      @staticmethod
      def multiple_aug_score(x_train, y_train, test_info):
          """Fit a regression ``FeatureAugmentReuser`` over the aligned learnwares.

          Returns:
              The fitted ``FeatureAugmentReuser``.
          """
          hetero_learnware_list = HeteroMethods._aligned_learnwares(x_train, y_train, test_info)
          reuse_multiple_augment = FeatureAugmentReuser(hetero_learnware_list, mode="regression")
          reuse_multiple_augment.fit(x_train=x_train, y_train=y_train)
          return reuse_multiple_augment

      @staticmethod
      def multiple_ensemble_pruning_score(x_train, y_train, test_info):
          """Fit a regression ``EnsemblePruningReuser`` over the aligned learnwares.

          Returns:
              The single aligned learnware when exactly one is given, otherwise the
              fitted ``EnsemblePruningReuser``.
          """
          hetero_learnware_list = HeteroMethods._aligned_learnwares(x_train, y_train, test_info)
          if len(hetero_learnware_list) == 1:
              # Nothing to prune with one candidate; return the aligned learnware itself.
              return hetero_learnware_list[0]
          reuse_pruning = EnsemblePruningReuser(hetero_learnware_list, mode="regression")
          # NOTE(review): the pruner is fitted on the training data itself (no held-out split).
          reuse_pruning.fit(val_X=x_train, val_y=y_train)
          return reuse_pruning

      @staticmethod
      def multiple_avg_score(x_train, y_train, test_info):
          """Build an ``AveragingReuser`` (``mode="mean"``) over the aligned learnwares.

          The training data is used only for alignment; the averaging reuser itself
          is not fitted.
          """
          hetero_learnware_list = HeteroMethods._aligned_learnwares(x_train, y_train, test_info)
          return AveragingReuser(hetero_learnware_list, mode="mean")
  # Registry of evaluation methods. Every value is a callable taking
  # ``(x_train, y_train, test_info)``; key order is the order methods are listed.
  test_methods = dict(
      user_model=user_model_score,
      hetero_single_aug=HeteroMethods.single_aug_score,
      hetero_multiple_aug=HeteroMethods.multiple_aug_score,
      hetero_multiple_avg=HeteroMethods.multiple_avg_score,
      hetero_ensemble_pruning=HeteroMethods.multiple_ensemble_pruning_score,
      homo_single_aug=HomoScoringMethods.single_aug_score,
      homo_multiple_aug=HomoScoringMethods.multiple_aug_score,
      homo_multiple_avg=HomoScoringMethods.multiple_avg_score,
      homo_ensemble_pruning=HomoScoringMethods.multiple_ensemble_pruning_score,
  )