You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

workflow.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import fire
  2. from config import (
  3. hetero_cross_feat_eng_benchmark_config,
  4. hetero_cross_task_benchmark_config,
  5. homo_table_benchmark_config,
  6. )
  7. from hetero import HeterogeneousDatasetWorkflow
  8. from homo import HomogeneousDatasetWorkflow
  9. from learnware.logger import get_module_logger
# Module-level logger registered under the name "base_table" at INFO level.
# NOTE(review): nothing else in this file as shown references `logger`;
# confirm it is not used before removing it.
logger = get_module_logger("base_table", level="INFO")
  11. class TableDatasetWorkflow:
  12. def unlabeled_homo_table_example(self, rebuild=True):
  13. workflow = HomogeneousDatasetWorkflow(
  14. benchmark_config=homo_table_benchmark_config, name="easy", rebuild=rebuild
  15. )
  16. workflow.unlabeled_homo_table_example()
  17. def labeled_homo_table_example(self, skip_test=False, rebuild=True):
  18. workflow = HomogeneousDatasetWorkflow(
  19. benchmark_config=homo_table_benchmark_config, name="easy", rebuild=rebuild
  20. )
  21. workflow.labeled_homo_table_example(skip_test=skip_test)
  22. def cross_feat_eng_hetero_table_example(self, rebuild=True, retrain=True):
  23. workflow = HeterogeneousDatasetWorkflow(
  24. benchmark_config=hetero_cross_feat_eng_benchmark_config, name="hetero", rebuild=rebuild, retrain=retrain
  25. )
  26. workflow.unlabeled_hetero_table_example()
  27. def cross_task_hetero_table_example(self, skip_test=False, rebuild=True, retrain=True):
  28. workflow = HeterogeneousDatasetWorkflow(
  29. benchmark_config=hetero_cross_task_benchmark_config, name="hetero", rebuild=rebuild, retrain=retrain
  30. )
  31. workflow.labeled_hetero_table_example(skip_test=skip_test)
  32. if __name__ == "__main__":
  33. fire.Fire(TableDatasetWorkflow)