diff --git a/main.py b/main.py index e6c8ce8..b9e1e89 100644 --- a/main.py +++ b/main.py @@ -14,14 +14,15 @@ def main(stage, precision, seed, dataset_path, + k_fold, + kth_fold_start, gpus=None, tpu_cores=None, version_nth=None, path_final_save=None, every_n_epochs=1, save_top_k=1, - k_fold=5, - kth_fold_start=0): + ): """ 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 该函数的参数为训练过程中需要经常改动的参数 @@ -99,8 +100,8 @@ def main(stage, if __name__ == "__main__": - main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5, + main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=10, + kth_fold_start=9, # gpus=1, # version_nth=8, - kth_fold_start=4, )