diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 00d4b00e6..7d9d951a9 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -167,6 +167,7 @@ type TaskInfo struct { CodeName string `json:"code_name"` BenchmarkCategory []string `json:"selected_category"` CodeLink string `json:"code_link"` + GpuType string `json:"gpu_type"` } func ConvertToTaskPod(input map[string]interface{}) (TaskPod, error) { diff --git a/routers/repo/cloudbrain.go b/routers/repo/cloudbrain.go index 0c1b34836..f3d28eb5f 100755 --- a/routers/repo/cloudbrain.go +++ b/routers/repo/cloudbrain.go @@ -170,12 +170,12 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) { benchmarkPath := setting.JobPath + jobName + cloudbrain.BenchMarkMountPath if setting.IsBenchmarkEnabled && jobType == string(models.JobTypeBenchmark) { gpuType = form.GpuType - downloadRateCode(repo, jobName, setting.BenchmarkCode, benchmarkPath, form.BenchmarkCategory) + downloadRateCode(repo, jobName, setting.BenchmarkCode, benchmarkPath, form.BenchmarkCategory, gpuType) } snn4imagenetPath := setting.JobPath + jobName + cloudbrain.Snn4imagenetMountPath if setting.IsSnn4imagenetEnabled && jobType == string(models.JobTypeSnn4imagenet) { - downloadRateCode(repo, jobName, setting.Snn4imagenetCode, snn4imagenetPath, "") + downloadRateCode(repo, jobName, setting.Snn4imagenetCode, snn4imagenetPath, "", "") } err = cloudbrain.GenerateTask(ctx, jobName, image, command, uuid, codePath, modelPath, benchmarkPath, snn4imagenetPath, jobType, gpuType) @@ -344,7 +344,7 @@ func downloadCode(repo *models.Repository, codePath string) error { return nil } -func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath, benchmarkCategory string) error { +func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath, benchmarkCategory, gpuType string) error { err := os.MkdirAll(codePath, os.ModePerm) if err != nil { log.Error("mkdir codePath failed", err.Error()) @@ -375,6 +375,7 @@ func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath, benc CodeName: repo.Name, BenchmarkCategory: strings.Split(benchmarkCategory, ","), CodeLink: strings.TrimSuffix(repo.CloneLink().HTTPS, ".git"), + GpuType: gpuType, }) if err != nil { log.Error("json.Marshal failed", err.Error())