You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

config.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from learnware.tests.benchmarks import BenchmarkConfig
  2. n_labeled_list = [100, 200, 500, 1000, 2000, 4000, 6000, 8000, 10000]
  3. n_repeat_list = [10, 10, 10, 3, 3, 3, 3, 3, 3]
  4. styles = {
  5. 'user_model': {"color": "navy", "marker": "o", "linestyle": "-"},
  6. 'select_score': {'color': 'gold', 'marker': 's', 'linestyle': '--'},
  7. 'oracle_score': {'color': 'darkorange', 'marker': '^', 'linestyle': '-.'},
  8. 'mean_score': {'color': 'gray', 'marker': 'x', 'linestyle': ':'},
  9. 'single_aug': {'color': 'gold', 'marker': 's', 'linestyle': '--'},
  10. 'multiple_avg': {'color': 'blue', 'marker': '*', 'linestyle': '-'},
  11. 'multiple_aug': {'color': 'purple', 'marker': 'd', 'linestyle': '--'},
  12. 'ensemble_pruning': {"color": "magenta", "marker": "d", "linestyle": "-."}
  13. }
  14. labels = {
  15. 'user_model': "User Model",
  16. 'single_aug': "Single Learnware Reuse (Select)",
  17. "select_score": "Single Learnware Reuse (Select)",
  18. 'multiple_aug': "Multiple Learnware Reuse (FeatAug)",
  19. 'ensemble_pruning': "Multiple Learnware Reuse (EnsemblePrune)",
  20. 'multiple_avg': "Multiple Learnware Reuse (Averaging)"
  21. }
  22. align_model_params = {
  23. "network_type": "ArbitraryMapping", # ["ArbitraryMapping", "BaseMapping", "BaseMapping_BN", "BaseMapping_Dropout"]
  24. "num_epoch": 50,
  25. "lr": 1e-5,
  26. "dropout_ratio": 0.2,
  27. "activation": "relu",
  28. "use_bn": True,
  29. "hidden_dims": [128, 256, 128, 256],
  30. }
  31. market_mapping_params = {
  32. "lr": 1e-4, # [5e-5, 1e-4, 2e-4, 5e-4],
  33. "num_epoch": 50,
  34. "batch_size": 64, # [64, 128, 256, 512, 1024],
  35. "num_partition": 2, # [2, 3, 4], # num of column partitions for pos/neg sampling
  36. "overlap_ratio": 0.7, # [0.1, 0.3, 0.5, 0.7], # specify the overlap ratio of column partitions during the CL
  37. "hidden_dim": 256, # [64, 128, 256, 512, 768, 1024], # the dimension of hidden embeddings
  38. "num_layer": 6, # [4, 6, 8, 10, 12, 14, 16, 20], # the number of transformer layers used in the encoder
  39. "num_attention_head": 8, # [4, 8, 16], # the numebr of heads of multihead self-attention layer in the transformers, should be divisible by hidden_dim
  40. "hidden_dropout_prob": 0.5, # [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], # the dropout ratio in the transformer encoder
  41. "ffn_dim": 512, # [128, 256, 512, 768, 1024], # the dimension of feed-forward layer in the transformer layer
  42. "activation": "leakyrelu",
  43. }
  44. user_model_params = {
  45. "Corporacion": {
  46. "lgb": {
  47. "params": {
  48. "num_leaves": 31,
  49. "objective": "regression",
  50. "learning_rate": 0.1,
  51. "feature_fraction": 0.8,
  52. "bagging_fraction": 0.8,
  53. "bagging_freq": 2,
  54. "metric": "l2",
  55. "num_threads": 4,
  56. "verbose": -1,
  57. },
  58. "MAX_ROUNDS": 500,
  59. "early_stopping_rounds": 50,
  60. }
  61. }
  62. }
  63. homo_table_benchmark_config = BenchmarkConfig(
  64. name="Corporacion",
  65. user_num=54,
  66. learnware_ids=[
  67. "00000912",
  68. "00000911",
  69. "00000910",
  70. "00000909",
  71. "00000908",
  72. "00000907",
  73. "00000906",
  74. "00000905",
  75. "00000904",
  76. "00000903",
  77. "00000902",
  78. "00000901",
  79. "00000900",
  80. "00000899",
  81. "00000898",
  82. "00000897",
  83. "00000896",
  84. "00000895",
  85. "00000894",
  86. "00000893",
  87. "00000892",
  88. "00000891",
  89. "00000890",
  90. "00000889",
  91. "00000888",
  92. "00000887",
  93. "00000886",
  94. "00000885",
  95. "00000884",
  96. "00000883",
  97. "00000882",
  98. "00000881",
  99. "00000880",
  100. "00000879",
  101. "00000878",
  102. "00000877",
  103. "00000876",
  104. "00000875",
  105. "00000874",
  106. "00000873",
  107. "00000872",
  108. "00000871",
  109. "00000870",
  110. "00000869",
  111. "00000868",
  112. "00000867",
  113. "00000866",
  114. "00000865",
  115. "00000864",
  116. "00000863",
  117. "00000862",
  118. "00000861",
  119. "00000860",
  120. "00000859"
  121. ],
  122. test_data_path="Corporacion/test_data.zip",
  123. train_data_path="Corporacion/train_data.zip",
  124. extra_info_path="Corporacion/extra_info.zip",
  125. )