|
|
|
@@ -57,7 +57,7 @@ def main(stage, |
|
|
|
gpus = 1 if torch.cuda.is_available() and gpus is None and tpu_cores is None else None |
|
|
|
# 定义不常改动的通用参数 |
|
|
|
# TODO 获得最优的batch size |
|
|
|
num_workers = cpu_count() |
|
|
|
num_workers = min([cpu_count(), 8]) |
|
|
|
# 获得非通用参数 |
|
|
|
config = {'dim_in': 32, } |
|
|
|
for kth_fold in range(kth_fold_start, k_fold): |
|
|
|
|