You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import numpy as np
  2. import lightgbm as lgb
  3. from lightgbm import early_stopping
  4. from sklearn.metrics import mean_squared_error
  5. from learnware.logger import get_module_logger
  6. from config import user_model_params
# Module-level logger used by the training helpers below for progress and
# validation-metric messages.
logger = get_module_logger("train_table", level="INFO")
  8. def train_lgb(X_train, y_train, X_val, y_val, dataset):
  9. logger.info("Training and predicting models...")
  10. model_param = user_model_params[dataset]["lgb"]
  11. params = model_param["params"]
  12. MAX_ROUNDS = model_param["MAX_ROUNDS"]
  13. val_pred = []
  14. cate_vars = []
  15. logger.info(f"{np.shape(X_train)}, {np.shape(y_train)}, {np.shape(X_val)}, {np.shape(y_val)}")
  16. dtrain = lgb.Dataset(X_train, label=y_train, categorical_feature=cate_vars)
  17. dval = lgb.Dataset(X_val, label=y_val, reference=dtrain, categorical_feature=cate_vars)
  18. bst = lgb.train(
  19. params,
  20. dtrain,
  21. num_boost_round=MAX_ROUNDS,
  22. valid_sets=[dtrain, dval],
  23. callbacks=[early_stopping(model_param["early_stopping_rounds"], verbose=False)]
  24. )
  25. val_pred.append(bst.predict(X_val, num_iteration=bst.best_iteration or MAX_ROUNDS))
  26. logger.info(f"Validation mse:{mean_squared_error(y_val, np.array(val_pred).transpose())}")
  27. return bst
  28. def train_ridge(X_train, y_train, X_val, y_val, dataset):
  29. pass
  30. def train_model(X_train, y_train, X_val, y_val, test_info):
  31. dataset = test_info["dataset"]
  32. model_type = test_info["model_type"]
  33. assert model_type in ["lgb", "ridge"]
  34. if model_type == "lgb":
  35. return train_lgb(X_train, y_train, X_val, y_val, dataset)