| @@ -14,14 +14,15 @@ def main(stage, | |||||
| precision, | precision, | ||||
| seed, | seed, | ||||
| dataset_path, | dataset_path, | ||||
| k_fold, | |||||
| kth_fold_start, | |||||
| gpus=None, | gpus=None, | ||||
| tpu_cores=None, | tpu_cores=None, | ||||
| version_nth=None, | version_nth=None, | ||||
| path_final_save=None, | path_final_save=None, | ||||
| every_n_epochs=1, | every_n_epochs=1, | ||||
| save_top_k=1, | save_top_k=1, | ||||
| k_fold=5, | |||||
| kth_fold_start=0): | |||||
| ): | |||||
| """ | """ | ||||
| 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 | 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 | ||||
| 该函数的参数为训练过程中需要经常改动的参数 | 该函数的参数为训练过程中需要经常改动的参数 | ||||
| @@ -99,8 +100,8 @@ def main(stage, | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5, | |||||
| main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=10, | |||||
| kth_fold_start=9, | |||||
| # gpus=1, | # gpus=1, | ||||
| # version_nth=8, | # version_nth=8, | ||||
| kth_fold_start=4, | |||||
| ) | ) | ||||