
pfs_cross_transfer.py 14 kB

import os
import pickle

import joblib
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from .paths import pfs_split_dir, pfs_res_dir, model_dir

np.seterr(divide="ignore", invalid="ignore")
np.random.seed(0)
def load_pfs_data(fpath):
    """Load one shop's CSV and split it into a feature matrix and the target column."""
    df = pd.read_csv(fpath)
    features = list(df.columns)
    features.remove("item_cnt_month")
    features.remove("date_block_num")
    # remove id info
    # features.remove("shop_id")
    # features.remove("item_id")
    # remove discrete info
    # features.remove("city_code")
    # features.remove("item_category_code")
    # features.remove("item_category_common")
    xs = df[features].values
    ys = df["item_cnt_month"].values
    # categorical columns, kept for reference; not passed to the models below
    categorical_feature_names = ["country_part", "item_category_common", "item_category_code", "city_code"]
    types = None
    return xs, ys, features, types
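
# Minimal usage sketch for load_pfs_data, assuming the split CSVs contain
# "item_cnt_month" and "date_block_num" plus the engineered feature columns
# (the shop number below is only an example):
#
#   xs, ys, features, _ = load_pfs_data(os.path.join(pfs_split_dir, "Shop02-train.csv"))
#   assert xs.shape == (len(ys), len(features))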
def get_split_errs(algo):
    """
    For each shop, retrain on only the most recent `p` training rows for each
    p in proportion_list; return errs with shape [shop, split_size].
    """
    shop_ids = [i for i in range(60) if i not in [0, 1, 40]]
    shop_ids = [i for i in shop_ids if i not in [8, 11, 23, 36]]
    user_list = list(range(53))
    proportion_list = [100, 300, 500, 700, 900, 1000, 3000, 5000, 7000, 9000, 10000, 30000, 50000, 70000]
    errs = np.zeros((len(user_list), len(proportion_list)))
    for s, sid in enumerate(user_list):
        # load train / validation data
        fpath = os.path.join(pfs_split_dir, "Shop{:0>2d}-train.csv".format(shop_ids[sid]))
        fpath_val = os.path.join(pfs_split_dir, "Shop{:0>2d}-val.csv".format(shop_ids[sid]))
        train_xs, train_ys, _, _ = load_pfs_data(fpath)
        val_xs, val_ys, _, _ = load_pfs_data(fpath_val)
        print(shop_ids[sid], train_xs.shape, train_ys.shape)
        # data normalization (disabled)
        # train_xs = (train_xs - train_xs.min(0)) / (train_xs.max(0) - train_xs.min(0) + 0.0001)
        # val_xs = (val_xs - val_xs.min(0)) / (val_xs.max(0) - val_xs.min(0) + 0.0001)
        if algo == "lgb":
            for tmp in range(len(proportion_list)):
                model = lgb.LGBMModel(
                    boosting_type="gbdt",
                    num_leaves=2**7 - 1,
                    learning_rate=0.01,
                    objective="rmse",
                    metric="rmse",
                    feature_fraction=0.75,
                    bagging_fraction=0.75,
                    bagging_freq=5,
                    seed=1,
                    verbose=1,
                    n_estimators=100000,
                )
                # reuse the hyperparameters of the full model trained in get_errors()
                model_ori = joblib.load(os.path.join(model_dir, "{}_Shop{:0>2d}.out".format("lgb", shop_ids[sid])))
                para = model_ori.get_params()
                para["n_estimators"] = 1000
                model.set_params(**para)
                # keep only the most recent proportion_list[tmp] training rows;
                # max(0, ...) falls back to the full set when fewer rows exist
                split = max(0, train_xs.shape[0] - proportion_list[tmp])
                model.fit(
                    train_xs[split:],
                    train_ys[split:],
                    eval_set=[(val_xs, val_ys)],
                    early_stopping_rounds=50,  # assumes lightgbm < 4.0, where fit() still accepts this
                    verbose=100,
                )
                pred_ys = model.predict(val_xs)
                rmse = np.sqrt(((val_ys - pred_ys) ** 2).mean())
                errs[s][tmp] = rmse
    return errs
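
# Sketch of how the resulting learning-curve matrix might be inspected
# (the output file name here is an assumption, not something this module
# writes on its own):
#
#   errs = get_split_errs("lgb")  # shape [len(user_list), len(proportion_list)]
#   np.savetxt(os.path.join(pfs_res_dir, "PFS_lgb_split_errs.txt"), errs)
#   print(errs.mean(axis=0))  # average RMSE per training-set size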
def get_errors(algo):
    """Train one model per shop, then evaluate every model on every shop's validation set."""
    shop_ids = [i for i in range(60) if i not in [0, 1, 40]]
    shop_ids = [i for i in shop_ids if i not in [8, 11, 23, 36]]
    K = len(shop_ids)
    feature_weight = np.zeros(())  # placeholder; resized once the feature count is known
    errs = np.zeros((K, K))
    for s, sid in enumerate(shop_ids):
        # load train / validation data
        fpath = os.path.join(pfs_split_dir, "Shop{:0>2d}-train.csv".format(sid))
        fpath_val = os.path.join(pfs_split_dir, "Shop{:0>2d}-val.csv".format(sid))
        train_xs, train_ys, features, _ = load_pfs_data(fpath)
        val_xs, val_ys, _, _ = load_pfs_data(fpath_val)
        print(sid, train_xs.shape, train_ys.shape)
        if s == 0:
            feature_weight = np.zeros((K, len(features)))
        if algo == "lgb":
            model = lgb.LGBMModel(
                boosting_type="gbdt",
                num_leaves=2**7 - 1,
                learning_rate=0.01,
                objective="rmse",
                metric="rmse",
                feature_fraction=0.75,
                bagging_fraction=0.75,
                bagging_freq=5,
                seed=1,
                verbose=1,
                n_estimators=1000,
            )
            # data normalization (disabled)
            # train_xs = (train_xs - train_xs.min(0)) / (train_xs.max(0) - train_xs.min(0) + 0.0001)
            # val_xs = (val_xs - val_xs.min(0)) / (val_xs.max(0) - val_xs.min(0) + 0.0001)
            model.fit(train_xs, train_ys, eval_set=[(val_xs, val_ys)], early_stopping_rounds=100, verbose=100)
            # grid search (disabled)
            # para = {"learning_rate": [0.005, 0.01, 0.015], "num_leaves": [128, 224, 300], "max_depth": [50, 66, 80]}
            # grid_search = GridSearchCV(model, para, scoring="neg_mean_squared_error")
            # grid_result = grid_search.fit(train_xs, train_ys, eval_set=[(val_xs, val_ys)], verbose=1000, early_stopping_rounds=1000)
            # model = grid_result.best_estimator_
            joblib.dump(model, os.path.join(model_dir, "{}_Shop{:0>2d}.out".format(algo, sid)))
            importances = model.feature_importances_
        elif algo == "ridge":
            # train_xs = (train_xs - train_xs.min(0)) / (train_xs.max(0) - train_xs.min(0) + 0.0001)
            model = Ridge()
            para = {"alpha": [0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10, 20, 30]}
            grid_search = GridSearchCV(model, para)
            grid_result = grid_search.fit(train_xs, train_ys)
            model = grid_result.best_estimator_
            importances = model.coef_
            joblib.dump(model, os.path.join(model_dir, "{}_Shop{:0>2d}.out".format(algo, sid)))
        feature_weight[s] = importances
        # leave-one-out transfer test: evaluate shop sid's model on every shop
        for t, tid in enumerate(shop_ids):
            fpath = os.path.join(pfs_split_dir, "Shop{:0>2d}-val.csv".format(tid))
            test_xs, test_ys, _, _ = load_pfs_data(fpath)
            # test_xs = (test_xs - test_xs.min(0)) / (test_xs.max(0) - test_xs.min(0) + 0.0001)
            pred_ys = model.predict(test_xs)
            rmse = np.sqrt(((test_ys - pred_ys) ** 2).mean())
            print("Shop{:0>2d} --> Shop{:0>2d}: {}".format(sid, tid, rmse))
            errs[s][t] = rmse
    np.savetxt(os.path.join(pfs_res_dir, "PFS_{}_weights.txt".format(algo)), feature_weight)
    return errs
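
# Interpretation note: errs[s][t] is the RMSE of the model trained on
# shop_ids[s] when evaluated on shop_ids[t]'s validation split, so diagonal
# entries are in-domain errors and off-diagonal entries measure cross-shop
# transfer. __main__ transposes the matrix before plotting, which makes
# rows targets ("Task") and columns sources ("Model") in the heatmap.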
def plot_heatmap(mat, algo):
    """Heatmap of the transfer matrix: rows are target tasks, columns are source models."""
    x_labels = [f"Model{i}" for i in range(mat.shape[1])]
    y_labels = [f"Task{i}" for i in range(mat.shape[0])]
    fig = plt.figure(figsize=(10, 9))
    ax = plt.subplot(1, 1, 1)
    im = ax.imshow(mat)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="4%", pad=0.3)
    plt.colorbar(im, cax=cax)
    # label every 5th row/column; combining set_xticklabels with a
    # MultipleLocator would pair the wrong labels with the sparse ticks
    ax.set_xticks(range(0, len(x_labels), 5))
    ax.set_xticklabels(x_labels[::5])
    ax.set_yticks(range(0, len(y_labels), 5))
    ax.set_yticklabels(y_labels[::5])
    ax.set_title(f"RMSE on Test set ({algo})")
    plt.tight_layout()
    plt.savefig(os.path.join(pfs_res_dir, "PFS_{}_heatmap.jpg".format(algo)), dpi=700)
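
# A quick way to exercise plot_heatmap without trained models; the 53x53
# shape matches the number of retained shops, and the random matrix is a
# stand-in, not real results:
#
#   demo = np.random.rand(53, 53)
#   plot_heatmap(demo, "demo")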
def plot_var(errs, algo):
    """Per-target statistics over all source models, excluding each target's own model."""
    avg_err, min_err, med_err, max_err, std_err = [], [], [], [], []
    cnts = []
    improves = []
    for j in range(len(errs)):
        inds = [i for i in range(len(errs)) if i != j]
        ys = errs[:, j][inds]
        avg_err.append(np.mean(ys))
        min_err.append(np.min(ys))
        med_err.append(np.median(ys))
        max_err.append(np.max(ys))
        std_err.append(np.std(ys))
        cnts.append(np.sum(ys >= np.mean(ys)))
        improves.append((np.mean(ys) - np.min(ys)) / np.mean(ys))
    avg_err = np.array(avg_err)
    min_err = np.array(min_err)
    med_err = np.array(med_err)
    max_err = np.array(max_err)
    std_err = np.array(std_err)
    cnts = np.array(cnts)
    improves = np.array(improves)
    # sort all statistics by average error
    inds = np.argsort(avg_err)
    avg_err, min_err, med_err = avg_err[inds], min_err[inds], med_err[inds]
    max_err, std_err = max_err[inds], std_err[inds]
    cnts, improves = cnts[inds], improves[inds]
    xs = list(range(len(inds)))
    fig = plt.figure(figsize=(8, 8))
    ax = plt.subplot(3, 1, 1)
    ax.plot(xs, avg_err, color="red", linestyle="solid", linewidth=2.5)
    ax.plot(xs, min_err, color="blue", linestyle="dotted", linewidth=1.5)
    ax.plot(xs, med_err, color="purple", linestyle="solid", linewidth=1.0)
    ax.plot(xs, max_err, color="green", linestyle="dashed", linewidth=1.5)
    ax.legend(["Avg", "Min", "Median", "Max"], fontsize=14)
    ax.fill_between(xs, avg_err - std_err, avg_err + std_err, alpha=0.2)
    gap = np.mean(avg_err - min_err)
    ax.set_ylabel("RMSE", fontsize=14)
    ax.set_title("RMSE of Source Models ({}) [Avg-Min:{:.3f}]".format(algo, gap), fontsize=18)
    ax = plt.subplot(3, 1, 2)
    ax.bar(xs, cnts)
    ax.set_ylabel("Number", fontsize=14)
    ax.set_title("Number of sources above average", fontsize=18)
    ax = plt.subplot(3, 1, 3)
    ax.plot(xs, improves)
    ax.set_xlabel("Sorted Shop ID by Avg.Err", fontsize=14)
    ax.set_ylabel("Ratio", fontsize=14)
    ax.set_title("Best Improve Ratio: (Avg - Min) / Avg", fontsize=18)
    fig.tight_layout()
    fig.savefig(os.path.join(pfs_res_dir, "{}-var.jpg".format(algo)))
    plt.show()
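
# Orientation note: plot_var reads the matrix as [source, target] (it takes
# column j via errs[:, j] to gather every source's error on target j), while
# plot_performance below expects the transposed [target, source] layout;
# that is why __main__ calls plot_performance(errs.T, ...).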
def plot_performance(errs, weights, algo):
    """Like plot_var, but with a fourth panel of averaged feature importances.

    Expects errs as [target, source]: row i holds every source model's RMSE on target i.
    """
    avg_err, min_err, med_err, max_err, std_err = [], [], [], [], []
    cnts = []
    improves = []
    for i in range(errs.shape[0]):
        inds = [j for j in range(errs.shape[1]) if j != i]
        arr = errs[i][inds]
        avg_err.append(np.mean(arr))
        min_err.append(np.min(arr))
        med_err.append(np.median(arr))
        max_err.append(np.max(arr))
        std_err.append(np.std(arr))
        cnts.append(np.sum(arr >= np.mean(arr)))
        improves.append((np.mean(arr) - np.min(arr)) / np.mean(arr))
    avg_err = np.array(avg_err)
    min_err = np.array(min_err)
    med_err = np.array(med_err)
    max_err = np.array(max_err)
    std_err = np.array(std_err)
    cnts = np.array(cnts)
    improves = np.array(improves)
    # sort all statistics by average error
    inds = np.argsort(avg_err)
    avg_err, min_err, med_err = avg_err[inds], min_err[inds], med_err[inds]
    max_err, std_err = max_err[inds], std_err[inds]
    cnts, improves = cnts[inds], improves[inds]
    xs = list(range(len(inds)))
    fig = plt.figure(figsize=(12, 9))
    ax = plt.subplot(2, 2, 1)
    ax.plot(xs, avg_err, color="red", linestyle="solid", linewidth=2.5)
    ax.plot(xs, min_err, color="blue", linestyle="dotted", linewidth=1.5)
    ax.plot(xs, med_err, color="purple", linestyle="solid", linewidth=1.0)
    ax.plot(xs, max_err, color="green", linestyle="dashed", linewidth=1.5)
    ax.legend(["Avg", "Min", "Median", "Max"], fontsize=14)
    ax.fill_between(xs, avg_err - std_err, avg_err + std_err, alpha=0.2)
    gap = np.mean(avg_err - min_err)
    ax.set_ylabel("RMSE", fontsize=14)
    ax.set_title("RMSE of Source Models ({}) [Avg-Min:{:.3f}]".format(algo, gap), fontsize=18)
    ax = plt.subplot(2, 2, 2)
    ax.bar(xs, cnts)
    ax.set_ylabel("Number", fontsize=14)
    ax.set_title("Number of sources above average", fontsize=18)
    ax = plt.subplot(2, 2, 3)
    ax.plot(xs, improves)
    ax.set_xlabel("Sorted Shop ID by Avg.Err", fontsize=14)
    ax.set_ylabel("Ratio", fontsize=14)
    ax.set_title("Best Improve Ratio: (Avg - Min) / Avg", fontsize=18)
    ax = plt.subplot(2, 2, 4)
    # average importances over shops, then normalize so the proportions sum to 1
    weights = np.mean(weights, axis=0)
    weights = np.sort(weights / weights.sum())
    xs = list(range(len(weights)))
    ax.plot(xs, weights)
    # ax.set_xlabel("Sorted Feature ID by Avg.Feature_Importance", fontsize=14)
    ax.set_ylabel("Proportion", fontsize=14)
    ax.set_title("Avg.Feature_Importances", fontsize=18)
    fig.tight_layout()
    fig.savefig(os.path.join(pfs_res_dir, "PFS_{}_performance.png".format(algo)), dpi=700)
    # fig.savefig(f"{algo}_performance.png", dpi=700)
    plt.show()
if __name__ == "__main__":
    # for algo in ["ridge", "lgb", "xgboost_125"]:
    for algo in ["ridge"]:
        # cache the cross-transfer error matrix so reruns skip retraining
        fpath = os.path.join(pfs_res_dir, "{}_errs.pkl".format(algo))
        if os.path.exists(fpath):
            with open(fpath, "rb") as fr:
                errs = pickle.load(fr)
        else:
            errs = get_errors(algo=algo)
            with open(fpath, "wb") as fw:
                pickle.dump(errs, fw)
        index = ["Source{}".format(k) for k in range(len(errs))]
        columns = ["Target{}".format(k) for k in range(len(errs[0]))]
        df = pd.DataFrame(errs, index=index, columns=columns)
        fpath = os.path.join(pfs_res_dir, "PFS_{}_errs.txt".format(algo))
        # df.to_csv(fpath, index=True)
        np.savetxt(fpath, errs.T)
        # plot_var(errs, algo)
        plot_heatmap(errs.T, algo)
        weights = np.loadtxt(os.path.join(pfs_res_dir, "PFS_{}_weights.txt".format(algo)))
        plot_performance(errs.T, weights, algo)
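
# Note: because of the relative import (`from .paths import ...`), this script
# has to be launched as a module from the package root, e.g.
# `python -m <package>.pfs_cross_transfer`, where <package> stands in for the
# actual package name (not given in this file).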