
utils.py

import os
import pickle

import numpy as np
import pandas as pd
from lightgbm import LGBMClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
class TextDataLoader:
    """Loads a pickled (X, y) split for one uploader (train) or user (test) index."""

    def __init__(self, data_root, train: bool = True):
        self.data_root = data_root
        self.train = train

    def get_idx_data(self, idx=0):
        # Training data lives under "uploader/", test data under "user/";
        # both use the same "<role>_<idx>_{X,y}.pkl" naming scheme.
        role = "uploader" if self.train else "user"
        X_path = os.path.join(self.data_root, role, "%s_%d_X.pkl" % (role, idx))
        y_path = os.path.join(self.data_root, role, "%s_%d_y.pkl" % (role, idx))
        if not (os.path.exists(X_path) and os.path.exists(y_path)):
            raise IndexError("No data found for index %d" % idx)
        with open(X_path, "rb") as f:
            X = pickle.load(f)
        with open(y_path, "rb") as f:
            y = pickle.load(f)
        return X, y
def generate_uploader(data_x: pd.DataFrame, data_y: pd.Series, n_uploaders=50, data_save_root=None):
    """Split the corpus by discourse type and save one (X, y) pickle pair per uploader."""
    if data_save_root is None:
        return
    os.makedirs(data_save_root, exist_ok=True)
    types = data_x["discourse_type"].unique()
    # Guard against requesting more uploaders than there are discourse types.
    for i in range(min(n_uploaders, len(types))):
        indices = data_x["discourse_type"] == types[i]
        selected_X = data_x[indices]["discourse_text"].to_list()
        selected_y = data_y[indices].to_list()
        X_save_path = os.path.join(data_save_root, "uploader_%d_X.pkl" % i)
        y_save_path = os.path.join(data_save_root, "uploader_%d_y.pkl" % i)
        with open(X_save_path, "wb") as f:
            pickle.dump(selected_X, f)
        with open(y_save_path, "wb") as f:
            pickle.dump(selected_y, f)
        print("Saving to %s" % X_save_path)
def generate_user(data_x: pd.DataFrame, data_y: pd.Series, n_users=50, data_save_root=None):
    """Same split as generate_uploader, but saved under the "user_*" naming scheme."""
    if data_save_root is None:
        return
    os.makedirs(data_save_root, exist_ok=True)
    types = data_x["discourse_type"].unique()
    for i in range(min(n_users, len(types))):
        indices = data_x["discourse_type"] == types[i]
        selected_X = data_x[indices]["discourse_text"].to_list()
        selected_y = data_y[indices].to_list()
        X_save_path = os.path.join(data_save_root, "user_%d_X.pkl" % i)
        y_save_path = os.path.join(data_save_root, "user_%d_y.pkl" % i)
        with open(X_save_path, "wb") as f:
            pickle.dump(selected_X, f)
        with open(y_save_path, "wb") as f:
            pickle.dump(selected_y, f)
        print("Saving to %s" % X_save_path)
# Train Uploaders' models
def train(X, y, out_classes):
    """Fit a TF-IDF + LightGBM text classifier; returns the fitted vectorizer and model."""
    # NOTE: out_classes is kept for API compatibility but is not used by LGBMClassifier.
    vectorizer = TfidfVectorizer(stop_words="english")
    X_tfidf = vectorizer.fit_transform(X)
    lgbm = LGBMClassifier(boosting_type="dart", n_estimators=500, num_leaves=21)
    lgbm.fit(X_tfidf, y)
    return vectorizer, lgbm
def eval_prediction(pred_y, target_y):
    """Return accuracy; accepts hard labels (1-D) or per-class scores (2-D)."""
    if not isinstance(pred_y, np.ndarray):
        # Assume a torch tensor and move it to host memory.
        pred_y = pred_y.detach().cpu().numpy()
    if pred_y.ndim == 1:
        predicted = pred_y
    else:
        predicted = np.argmax(pred_y, axis=1)
    annos = np.array(target_y)
    total = predicted.shape[0]
    correct = (predicted == annos).sum().item()
    return correct / total
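

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a CSV in
# the Kaggle "Feedback Prize" layout, i.e. a "discourse_text" column, a
# "discourse_type" column, and an integer label column; the "train.csv" path
# and "label" column name below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    df = pd.read_csv("train.csv")  # hypothetical input file
    data_x = df[["discourse_text", "discourse_type"]]
    data_y = df["label"]  # hypothetical label column

    # Write per-uploader splits, then read one back through the loader.
    generate_uploader(data_x, data_y, n_uploaders=5, data_save_root="./data/uploader")
    loader = TextDataLoader("./data", train=True)
    X, y = loader.get_idx_data(idx=0)

    # Train on the split and report training accuracy.
    vectorizer, model = train(X, y, out_classes=len(set(y)))
    acc = eval_prediction(model.predict(vectorizer.transform(X)), y)
    print("uploader 0 train accuracy: %.4f" % acc)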