|
|
|
@@ -14,14 +14,15 @@ def main(stage, |
|
|
|
precision, |
|
|
|
seed, |
|
|
|
dataset_path, |
|
|
|
k_fold, |
|
|
|
kth_fold_start, |
|
|
|
gpus=None, |
|
|
|
tpu_cores=None, |
|
|
|
version_nth=None, |
|
|
|
path_final_save=None, |
|
|
|
every_n_epochs=1, |
|
|
|
save_top_k=1, |
|
|
|
k_fold=5, |
|
|
|
kth_fold_start=0): |
|
|
|
): |
|
|
|
""" |
|
|
|
框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 |
|
|
|
该函数的参数为训练过程中需要经常改动的参数 |
|
|
|
@@ -99,8 +100,8 @@ def main(stage, |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5, |
|
|
|
main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=10, |
|
|
|
kth_fold_start=9, |
|
|
|
# gpus=1, |
|
|
|
# version_nth=8, |
|
|
|
kth_fold_start=4, |
|
|
|
) |