From 24fdd1244e23d39b06d747321a215c23012d91c1 Mon Sep 17 00:00:00 2001 From: shenyan <23357320@qq.com> Date: Fri, 15 Oct 2021 20:29:55 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9main=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=9A=84=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index e6c8ce8..b9e1e89 100644 --- a/main.py +++ b/main.py @@ -14,14 +14,15 @@ def main(stage, precision, seed, dataset_path, + k_fold, + kth_fold_start, gpus=None, tpu_cores=None, version_nth=None, path_final_save=None, every_n_epochs=1, save_top_k=1, - k_fold=5, - kth_fold_start=0): + ): """ 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 该函数的参数为训练过程中需要经常改动的参数 @@ -99,8 +100,8 @@ def main(stage, if __name__ == "__main__": - main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5, + main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=10, + kth_fold_start=9, # gpus=1, # version_nth=8, - kth_fold_start=4, )