You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

workflow.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import fire
  2. from learnware.logger import get_module_logger
  3. from homo import HomogeneousDatasetWorkflow
  4. from hetero import HeterogeneousDatasetWorkflow
  5. from config import homo_table_benchmark_config, hetero_cross_feat_eng_benchmark_config, hetero_cross_task_benchmark_config
  6. from utils import set_seed
# Module-level logger named "base_table", created at INFO level. It is not
# referenced by any of the code below in this file.
logger = get_module_logger("base_table", level="INFO")
  8. class TableDatasetWorkflow:
  9. def unlabeled_homo_table_example(self):
  10. workflow = HomogeneousDatasetWorkflow(
  11. benchmark_config=homo_table_benchmark_config,
  12. name="easy",
  13. rebuild=True
  14. )
  15. workflow.unlabeled_homo_table_example()
  16. def labeled_homo_table_example(self):
  17. workflow = HomogeneousDatasetWorkflow(
  18. benchmark_config=homo_table_benchmark_config,
  19. name="easy",
  20. rebuild=False
  21. )
  22. workflow.labeled_homo_table_example()
  23. def cross_feat_eng_hetero_table_example(self):
  24. set_seed(0)
  25. workflow = HeterogeneousDatasetWorkflow(
  26. benchmark_config=hetero_cross_feat_eng_benchmark_config,
  27. name="hetero",
  28. rebuild=False,
  29. retrain=False
  30. )
  31. workflow.unlabeled_hetero_table_example()
  32. def cross_task_hetero_table_example(self):
  33. set_seed(0)
  34. workflow = HeterogeneousDatasetWorkflow(
  35. benchmark_config=hetero_cross_task_benchmark_config,
  36. name="hetero",
  37. rebuild=False,
  38. retrain=False
  39. )
  40. workflow.labeled_hetero_table_example()
  41. if __name__ == "__main__":
  42. fire.Fire(TableDatasetWorkflow)