You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 1.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import lightgbm as lgb
  2. from config import user_model_params
  3. from lightgbm import early_stopping
  4. from learnware.logger import get_module_logger
  5. logger = get_module_logger("train_table", level="INFO")
  6. def train_lgb(X_train, y_train, X_val, y_val, dataset):
  7. model_param = user_model_params[dataset]["lgb"]
  8. params = model_param["params"]
  9. MAX_ROUNDS = model_param["MAX_ROUNDS"]
  10. val_pred = []
  11. cate_vars = []
  12. dtrain = lgb.Dataset(X_train, label=y_train, categorical_feature=cate_vars)
  13. dval = lgb.Dataset(X_val, label=y_val, reference=dtrain, categorical_feature=cate_vars)
  14. bst = lgb.train(
  15. params,
  16. dtrain,
  17. num_boost_round=MAX_ROUNDS,
  18. valid_sets=[dtrain, dval] if dataset == "Corporacion" else [dval],
  19. callbacks=[early_stopping(model_param["early_stopping_rounds"], verbose=False)],
  20. )
  21. val_pred.append(bst.predict(X_val, num_iteration=bst.best_iteration or MAX_ROUNDS))
  22. return bst
  23. def train_ridge(X_train, y_train, X_val, y_val, dataset):
  24. pass
  25. def train_model(X_train, y_train, X_val, y_val, test_info):
  26. dataset = test_info["dataset"]
  27. model_type = test_info["model_type"]
  28. assert model_type in ["lgb", "ridge"]
  29. if model_type == "lgb":
  30. return train_lgb(X_train, y_train, X_val, y_val, dataset)