You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import numpy as np
  2. import lightgbm as lgb
  3. from lightgbm import early_stopping
  4. from sklearn.metrics import mean_squared_error
  5. from learnware.logger import get_module_logger
  6. from config import user_model_params
# Module-level logger obtained from learnware's logger factory,
# registered under the name "train_table" at INFO level.
logger = get_module_logger("train_table", level="INFO")
  8. def train_lgb(X_train, y_train, X_val, y_val, dataset):
  9. model_param = user_model_params[dataset]["lgb"]
  10. params = model_param["params"]
  11. MAX_ROUNDS = model_param["MAX_ROUNDS"]
  12. val_pred = []
  13. cate_vars = []
  14. dtrain = lgb.Dataset(X_train, label=y_train, categorical_feature=cate_vars)
  15. dval = lgb.Dataset(X_val, label=y_val, reference=dtrain, categorical_feature=cate_vars)
  16. bst = lgb.train(
  17. params,
  18. dtrain,
  19. num_boost_round=MAX_ROUNDS,
  20. valid_sets=[dtrain, dval] if dataset == "Corporacion" else [dval],
  21. callbacks=[early_stopping(model_param["early_stopping_rounds"], verbose=False)]
  22. )
  23. val_pred.append(bst.predict(X_val, num_iteration=bst.best_iteration or MAX_ROUNDS))
  24. return bst
  25. def train_ridge(X_train, y_train, X_val, y_val, dataset):
  26. pass
  27. def train_model(X_train, y_train, X_val, y_val, test_info):
  28. dataset = test_info["dataset"]
  29. model_type = test_info["model_type"]
  30. assert model_type in ["lgb", "ridge"]
  31. if model_type == "lgb":
  32. return train_lgb(X_train, y_train, X_val, y_val, dataset)