"""Command-line entry point that dispatches to the table dataset workflows.

Each public method of ``TableDatasetWorkflow`` becomes a sub-command when the
module is executed as a script (the class is handed to ``fire.Fire``).
"""

import fire
from config import (
    hetero_cross_feat_eng_benchmark_config,
    hetero_cross_task_benchmark_config,
    homo_table_benchmark_config,
)
from hetero import HeterogeneousDatasetWorkflow
from homo import HomogeneousDatasetWorkflow

from learnware.logger import get_module_logger

logger = get_module_logger("base_table", level="INFO")


def _make_homo_workflow(rebuild):
    """Build the homogeneous workflow on the homo table benchmark config."""
    return HomogeneousDatasetWorkflow(
        benchmark_config=homo_table_benchmark_config,
        name="easy",
        rebuild=rebuild,
    )


def _make_hetero_workflow(benchmark_config, rebuild, retrain):
    """Build the heterogeneous workflow on the given benchmark config."""
    return HeterogeneousDatasetWorkflow(
        benchmark_config=benchmark_config,
        name="hetero",
        rebuild=rebuild,
        retrain=retrain,
    )


class TableDatasetWorkflow:
    """Thin dispatcher exposing the homogeneous/heterogeneous table examples.

    Every method constructs the matching workflow object and immediately
    invokes one example method on it; nothing is returned.

    NOTE(review): the exact meaning of ``rebuild``, ``retrain`` and
    ``skip_test`` is defined by the workflow classes in ``homo`` / ``hetero``;
    they are forwarded here unchanged.
    """

    def unlabeled_homo_table_example(self, rebuild=True):
        """Run ``unlabeled_homo_table_example`` on the homogeneous workflow.

        Args:
            rebuild: Forwarded to the ``HomogeneousDatasetWorkflow`` constructor.
        """
        homo_flow = _make_homo_workflow(rebuild)
        homo_flow.unlabeled_homo_table_example()

    def labeled_homo_table_example(self, skip_test=False, rebuild=True):
        """Run ``labeled_homo_table_example`` on the homogeneous workflow.

        Args:
            skip_test: Forwarded to the workflow's ``labeled_homo_table_example``.
            rebuild: Forwarded to the ``HomogeneousDatasetWorkflow`` constructor.
        """
        homo_flow = _make_homo_workflow(rebuild)
        homo_flow.labeled_homo_table_example(skip_test=skip_test)

    def cross_feat_eng_hetero_table_example(self, rebuild=True, retrain=True):
        """Run the unlabeled heterogeneous example on the cross-feat-eng config.

        Args:
            rebuild: Forwarded to the ``HeterogeneousDatasetWorkflow`` constructor.
            retrain: Forwarded to the ``HeterogeneousDatasetWorkflow`` constructor.
        """
        hetero_flow = _make_hetero_workflow(
            hetero_cross_feat_eng_benchmark_config, rebuild, retrain
        )
        hetero_flow.unlabeled_hetero_table_example()

    def cross_task_hetero_table_example(self, skip_test=False, rebuild=True, retrain=True):
        """Run the labeled heterogeneous example on the cross-task config.

        Args:
            skip_test: Forwarded to the workflow's ``labeled_hetero_table_example``.
            rebuild: Forwarded to the ``HeterogeneousDatasetWorkflow`` constructor.
            retrain: Forwarded to the ``HeterogeneousDatasetWorkflow`` constructor.
        """
        hetero_flow = _make_hetero_workflow(
            hetero_cross_task_benchmark_config, rebuild, retrain
        )
        hetero_flow.labeled_hetero_table_example(skip_test=skip_test)


if __name__ == "__main__":
    fire.Fire(TableDatasetWorkflow)