diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index b37c7b3b6..73d3172a5 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -859,6 +859,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) flavorCode := form.Flavor params := form.Params poolID := form.PoolID + isSaveParam := form.IsSaveParam repo := ctx.Repo.Repository codeLocalPath := setting.JobPath + jobName + modelarts.CodePath codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath @@ -983,6 +984,40 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) }) } + //save param config + if isSaveParam == "on" { + if form.ParameterTemplateName == "" { + log.Error("ParameterTemplateName is empty") + trainJobNewDataPrepare(ctx) + ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobNew, &form) + return + } + + _, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{ + ConfigName: form.ParameterTemplateName, + Description: form.PrameterDescription, + DataUrl: dataPath, + AppUrl: codeObsPath, + BootFileUrl: codeObsPath + bootFile, + TrainUrl: outputObsPath, + Flavor: models.Flavor{ + Code: flavorCode, + }, + WorkServerNum: workServerNumber, + EngineID: int64(engineID), + LogUrl: logObsPath, + PoolID: poolID, + Parameter: param, + }) + + if err != nil { + log.Error("Failed to CreateTrainJobConfig: %v", err) + trainJobErrorNewDataPrepare(ctx, form) + ctx.RenderWithErr("保存作业参数失败:"+err.Error(), tplModelArtsTrainJobNew, &form) + return + } + } + req := &modelarts.GenerateTrainJobReq{ JobName: jobName, DisplayJobName: displayJobName, @@ -1062,6 +1097,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ flavorCode := form.Flavor params := form.Params poolID := form.PoolID + isSaveParam := form.IsSaveParam repo := ctx.Repo.Repository codeLocalPath := setting.JobPath + jobName + modelarts.CodePath codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/" @@ -1161,6 +1197,40 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ }) } + //save param config + if isSaveParam == "on" { + if form.ParameterTemplateName == "" { + log.Error("ParameterTemplateName is empty") + versionErrorDataPrepare(ctx, form) + ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobVersionNew, &form) + return + } + + _, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{ + ConfigName: form.ParameterTemplateName, + Description: form.PrameterDescription, + DataUrl: dataPath, + AppUrl: codeObsPath, + BootFileUrl: codeObsPath + bootFile, + TrainUrl: outputObsPath, + Flavor: models.Flavor{ + Code: flavorCode, + }, + WorkServerNum: workServerNumber, + EngineID: int64(engineID), + LogUrl: logObsPath, + PoolID: poolID, + Parameter: parameters.Parameter, + }) + + if err != nil { + log.Error("Failed to CreateTrainJobConfig: %v", err) + versionErrorDataPrepare(ctx, form) + ctx.RenderWithErr("保存作业参数失败:"+err.Error(), tplModelArtsTrainJobVersionNew, &form) + return + } + } + task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, PreVersionName) if err != nil { log.Error("GetCloudbrainByJobIDAndVersionName(%s) failed:%v", jobID, err.Error())