diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index b740b1167..e30d0100c 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -51,6 +51,8 @@ const ( DataUrl = "data_url" ResultUrl = "result_url" CkptUrl = "ckpt_url" + DeviceTarget = "device_target" + Ascend = "Ascend" PerPage = 10 IsLatestVersion = "1" NotLatestVersion = "0" diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 9c670e203..b37c7b3b6 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -859,7 +859,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) flavorCode := form.Flavor params := form.Params poolID := form.PoolID - isSaveParam := form.IsSaveParam repo := ctx.Repo.Repository codeLocalPath := setting.JobPath + jobName + modelarts.CodePath codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath @@ -953,17 +952,9 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) return } - //todo: del local code? - var parameters models.Parameters param := make([]models.Parameter, 0) - param = append(param, models.Parameter{ - Label: modelarts.TrainUrl, - Value: outputObsPath, - }, models.Parameter{ - Label: modelarts.DataUrl, - Value: dataPath, - }) + existDeviceTarget := false if len(params) != 0 { err := json.Unmarshal([]byte(params), ¶meters) if err != nil { @@ -974,6 +965,9 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) } for _, parameter := range parameters.Parameter { + if parameter.Label == modelarts.DeviceTarget { + existDeviceTarget = true + } if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl { param = append(param, models.Parameter{ Label: parameter.Label, @@ -982,39 +976,11 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) } } } - - //save param config - if isSaveParam == "on" { - if form.ParameterTemplateName == "" { - log.Error("ParameterTemplateName is empty") - trainJobNewDataPrepare(ctx) - ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobNew, &form) - return - } - - _, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{ - ConfigName: form.ParameterTemplateName, - Description: form.PrameterDescription, - DataUrl: dataPath, - AppUrl: codeObsPath, - BootFileUrl: codeObsPath + bootFile, - TrainUrl: outputObsPath, - Flavor: models.Flavor{ - Code: flavorCode, - }, - WorkServerNum: workServerNumber, - EngineID: int64(engineID), - LogUrl: logObsPath, - PoolID: poolID, - Parameter: param, + if !existDeviceTarget { + param = append(param, models.Parameter{ + Label: modelarts.DeviceTarget, + Value: modelarts.Ascend, }) - - if err != nil { - log.Error("Failed to CreateTrainJobConfig: %v", err) - trainJobErrorNewDataPrepare(ctx, form) - ctx.RenderWithErr("保存作业参数失败:"+err.Error(), tplModelArtsTrainJobNew, &form) - return - } } req := &modelarts.GenerateTrainJobReq{ @@ -1032,7 +998,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) LogUrl: logObsPath, PoolID: poolID, Uuid: uuid, - Parameters: parameters.Parameter, + Parameters: param, CommitID: commitID, IsLatestVersion: isLatestVersion, BranchName: branch_name, @@ -1096,7 +1062,6 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ flavorCode := form.Flavor params := form.Params poolID := form.PoolID - isSaveParam := form.IsSaveParam repo := ctx.Repo.Repository codeLocalPath := setting.JobPath + jobName + modelarts.CodePath codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/" @@ -1168,13 +1133,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ var parameters models.Parameters param := make([]models.Parameter, 0) - param = append(param, models.Parameter{ - Label: modelarts.TrainUrl, - Value: outputObsPath, - }, models.Parameter{ - Label: modelarts.DataUrl, - Value: dataPath, - }) + existDeviceTarget := true if len(params) != 0 { err := json.Unmarshal([]byte(params), ¶meters) if err != nil { @@ -1183,8 +1142,10 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ ctx.RenderWithErr("运行参数错误", tplModelArtsTrainJobVersionNew, &form) return } - for _, parameter := range parameters.Parameter { + if parameter.Label == modelarts.DeviceTarget { + existDeviceTarget = true + } if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl { param = append(param, models.Parameter{ Label: parameter.Label, @@ -1193,45 +1154,11 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ } } } - - //save param config - if isSaveParam == "on" { - if form.ParameterTemplateName == "" { - log.Error("ParameterTemplateName is empty") - versionErrorDataPrepare(ctx, form) - ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobVersionNew, &form) - return - } - - _, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{ - ConfigName: form.ParameterTemplateName, - Description: form.PrameterDescription, - DataUrl: dataPath, - AppUrl: codeObsPath, - BootFileUrl: codeObsPath + bootFile, - TrainUrl: outputObsPath, - Flavor: models.Flavor{ - Code: flavorCode, - }, - WorkServerNum: workServerNumber, - EngineID: int64(engineID), - LogUrl: logObsPath, - PoolID: poolID, - Parameter: parameters.Parameter, + if !existDeviceTarget { + param = append(param, models.Parameter{ + Label: modelarts.DeviceTarget, + Value: modelarts.Ascend, }) - - if err != nil { - log.Error("Failed to CreateTrainJobConfig: %v", err) - versionErrorDataPrepare(ctx, form) - ctx.RenderWithErr("保存作业参数失败:"+err.Error(), tplModelArtsTrainJobVersionNew, &form) - return - } - } - - if err != nil { - log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error()) - ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form) - return } task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, PreVersionName) @@ -1257,7 +1184,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ PoolID: poolID, Uuid: uuid, Params: form.Params, - Parameters: parameters.Parameter, + Parameters: param, PreVersionId: task.VersionID, CommitID: commitID, BranchName: branch_name, @@ -1782,7 +1709,6 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference return } - //todo: del local code? var parameters models.Parameters param := make([]models.Parameter, 0) param = append(param, models.Parameter{ @@ -1792,6 +1718,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference Label: modelarts.CkptUrl, Value: "s3:/" + ckptUrl, }) + existDeviceTarget := false if len(params) != 0 { err := json.Unmarshal([]byte(params), ¶meters) if err != nil { @@ -1802,6 +1729,9 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference } for _, parameter := range parameters.Parameter { + if parameter.Label == modelarts.DeviceTarget { + existDeviceTarget = true + } if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl { param = append(param, models.Parameter{ Label: parameter.Label, @@ -1810,6 +1740,12 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference } } } + if !existDeviceTarget { + param = append(param, models.Parameter{ + Label: modelarts.DeviceTarget, + Value: modelarts.Ascend, + }) + } req := &modelarts.GenerateInferenceJobReq{ JobName: jobName,