You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

methods.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import numpy as np
  2. from config import align_model_params
  3. from sklearn.metrics import mean_squared_error
  4. from sklearn.model_selection import train_test_split
  5. from train import train_model
  6. from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware
  7. def loss_func_rmse(y_true, y_pred):
  8. return np.sqrt(mean_squared_error(y_true, y_pred))
  9. def user_model_score(x_train, y_train, test_info):
  10. x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)
  11. user_model = train_model(x_train, y_train, x_val, y_val, test_info)
  12. return user_model
  13. class HomoScoringMethods:
  14. @staticmethod
  15. def single_aug_score(x_train, y_train, test_info):
  16. single_learnware = test_info["single_learnware"]
  17. reuse_single_augment = FeatureAugmentReuser(single_learnware, mode="regression")
  18. reuse_single_augment.fit(x_train=x_train, y_train=y_train)
  19. return reuse_single_augment
  20. @staticmethod
  21. def multiple_aug_score(x_train, y_train, test_info):
  22. multiple_learnwares = test_info["learnwares"]
  23. reuse_multiple_augment = FeatureAugmentReuser(multiple_learnwares, mode="regression")
  24. reuse_multiple_augment.fit(x_train=x_train, y_train=y_train)
  25. return reuse_multiple_augment
  26. @staticmethod
  27. def multiple_avg_score(x_train, y_train, test_info):
  28. multiple_learnwares = test_info["learnwares"]
  29. reuse_multiple_avg = AveragingReuser(multiple_learnwares, mode="mean")
  30. return reuse_multiple_avg
  31. @staticmethod
  32. def multiple_ensemble_pruning_score(x_train, y_train, test_info):
  33. multiple_learnwares = test_info["learnwares"]
  34. if len(multiple_learnwares) == 1:
  35. return multiple_learnwares[0]
  36. reuse_pruning = EnsemblePruningReuser(multiple_learnwares, mode="regression")
  37. reuse_pruning.fit(val_X=x_train, val_y=y_train)
  38. return reuse_pruning
  39. class HeteroMethods:
  40. @staticmethod
  41. def create_hetero_learnware_list(learnware_list, user_rkme, x_train, y_train):
  42. hetero_learnware_list = []
  43. for learnware in learnware_list:
  44. hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression", **align_model_params)
  45. hetero_learnware.align(user_rkme, x_train, y_train)
  46. hetero_learnware_list.append(hetero_learnware)
  47. return hetero_learnware_list
  48. @staticmethod
  49. def single_aug_score(x_train, y_train, test_info):
  50. user_rkme, single_learnware = test_info["user_rkme"], test_info["single_learnware"]
  51. reuse_single_augment = HeteroMapAlignLearnware(single_learnware, mode="regression", **align_model_params)
  52. reuse_single_augment.align(user_rkme=user_rkme, x_train=x_train, y_train=y_train)
  53. return reuse_single_augment
  54. @staticmethod
  55. def multiple_aug_score(x_train, y_train, test_info):
  56. user_rkme, multiple_learnwares = test_info["user_rkme"], test_info["learnwares"]
  57. hetero_learnware_list = HeteroMethods.create_hetero_learnware_list(
  58. multiple_learnwares, user_rkme, x_train, y_train
  59. )
  60. reuse_multiple_augment = FeatureAugmentReuser(hetero_learnware_list, mode="regression")
  61. reuse_multiple_augment.fit(x_train=x_train, y_train=y_train)
  62. return reuse_multiple_augment
  63. @staticmethod
  64. def multiple_ensemble_pruning_score(x_train, y_train, test_info):
  65. user_rkme, multiple_learnwares = test_info["user_rkme"], test_info["learnwares"]
  66. hetero_learnware_list = HeteroMethods.create_hetero_learnware_list(
  67. multiple_learnwares, user_rkme, x_train, y_train
  68. )
  69. if len(hetero_learnware_list) == 1:
  70. return hetero_learnware_list[0]
  71. reuse_pruning = EnsemblePruningReuser(hetero_learnware_list, mode="regression")
  72. reuse_pruning.fit(val_X=x_train, val_y=y_train)
  73. return reuse_pruning
  74. @staticmethod
  75. def multiple_avg_score(x_train, y_train, test_info):
  76. user_rkme, multiple_learnwares = test_info["user_rkme"], test_info["learnwares"]
  77. hetero_learnware_list = HeteroMethods.create_hetero_learnware_list(
  78. multiple_learnwares, user_rkme, x_train, y_train
  79. )
  80. reuse_multiple_avg = AveragingReuser(hetero_learnware_list, mode="mean")
  81. return reuse_multiple_avg
  82. test_methods = {
  83. "user_model": user_model_score,
  84. "hetero_single_aug": HeteroMethods.single_aug_score,
  85. "hetero_multiple_aug": HeteroMethods.multiple_aug_score,
  86. "hetero_multiple_avg": HeteroMethods.multiple_avg_score,
  87. "hetero_ensemble_pruning": HeteroMethods.multiple_ensemble_pruning_score,
  88. "homo_single_aug": HomoScoringMethods.single_aug_score,
  89. "homo_multiple_aug": HomoScoringMethods.multiple_aug_score,
  90. "homo_multiple_avg": HomoScoringMethods.multiple_avg_score,
  91. "homo_ensemble_pruning": HomoScoringMethods.multiple_ensemble_pruning_score,
  92. }