diff --git a/models/cloudbrain.go b/models/cloudbrain.go index d282c130d..20e643884 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -91,6 +91,7 @@ type Cloudbrain struct { Description string WorkServerNumber int FlavorName string + EngineName string User *User `xorm:"-"` Repo *Repository `xorm:"-"` diff --git a/modules/auth/modelarts.go b/modules/auth/modelarts.go index 2cbca1c08..97ca65c2d 100755 --- a/modules/auth/modelarts.go +++ b/modules/auth/modelarts.go @@ -19,7 +19,7 @@ type CreateModelArtsNotebookForm struct { JobName string `form:"job_name" binding:"Required"` Attachment string `form:"attachment"` Description string `form:"description"` - Flavor string `form:"flavor"` + Flavor string `form:"flavor"` } func (f *CreateModelArtsNotebookForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { @@ -42,6 +42,7 @@ type CreateModelArtsTrainJobForm struct { BranchName string `form:"branch_name" binding:"Required"` VersionName string `form:"version_name" binding:"Required"` FlavorName string `form:"flavor_name" binding:"Required"` + EngineName string `form:"engine_name" binding:"Required"` } func (f *CreateModelArtsTrainJobForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index cd07212f9..88319290a 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -82,6 +82,7 @@ type GenerateTrainJobReq struct { FatherVersionName string FlavorName string VersionCount int + EngineName string } type GenerateTrainJobVersionReq struct { @@ -104,6 +105,7 @@ type GenerateTrainJobVersionReq struct { CommitID string BranchName string FlavorName string + EngineName string } type VersionInfo struct { @@ -134,6 +136,15 @@ type ResourcePool struct { } `json:"resource_pool"` } +type Parameter struct { + Label string `json:"label"` + Value string `json:"value"` +} + +type Parameters struct { + Parameter []Parameter `json:"parameter"` +} + func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error { var dataActualPath string if uuid != "" { @@ -263,6 +274,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error Description: req.Description, WorkServerNumber: req.WorkServerNumber, FlavorName: req.FlavorName, + EngineName: req.EngineName, VersionCount: req.VersionCount, }) @@ -331,6 +343,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobVersionR Description: req.Description, WorkServerNumber: req.WorkServerNumber, FlavorName: req.FlavorName, + EngineName: req.EngineName, }) if err != nil { log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error()) diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index a91dbdb3a..28ad313d5 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -441,6 +441,13 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { } ctx.Data["flavor_infos"] = flavorInfos.Info + var Parameters modelarts.Parameters + if err = json.Unmarshal([]byte(task.Parameters), &Parameters); err != nil { + ctx.ServerError("json.Unmarshal failed:", err) + return err + } + ctx.Data["params"] = Parameters.Parameter + outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath ctx.Data["train_url"] = outputObsPath @@ -454,9 +461,9 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { ctx.Data["description"] = task.Description ctx.Data["boot_file"] = task.BootFile ctx.Data["dataset_name"] = task.DatasetName - ctx.Data["params"] = task.Parameters ctx.Data["work_server_number"] = task.WorkServerNumber ctx.Data["flavor_name"] = task.FlavorName + ctx.Data["engine_name"] = task.FlavorName ctx.Data["uuid"] = task.Uuid ctx.Data["flavor_code"] = task.FlavorCode ctx.Data["engine_id"] = task.EngineID @@ -493,6 +500,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) isLatestVersion := modelarts.IsLatestVersion FlavorName := form.FlavorName VersionCount := modelarts.VersionCount + EngineName := form.EngineName if err := paramCheckCreateTrainJob(form); err != nil { log.Error("paramCheckCreateTrainJob failed:(%v)", err) @@ -619,13 +627,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) return } } - //将引擎id转化为引擎名称 - FlavorName, err = getFlavorNameByEngineID(engineID) - if err != nil { - log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error()) - ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobNew, &form) - return - } req := &modelarts.GenerateTrainJobReq{ JobName: jobName, @@ -648,6 +649,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) Params: form.Params, FatherVersionName: modelarts.InitFatherVersionName, FlavorName: FlavorName, + EngineName: EngineName, VersionCount: VersionCount, } @@ -719,6 +721,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ branch_name := form.BranchName fatherVersionName := versionName FlavorName := form.FlavorName + EngineName := form.EngineName if err := paramCheckCreateTrainJob(form); err != nil { log.Error("paramCheckCreateTrainJob failed:(%v)", err) @@ -843,8 +846,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ return } } - //将引擎id转化为引擎名称 - FlavorName, err = getFlavorNameByEngineID(engineID) + if err != nil { log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error()) ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form) @@ -876,6 +878,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ CommitID: commitID, BranchName: branch_name, FlavorName: FlavorName, + EngineName: EngineName, } err = modelarts.GenerateTrainJobVersion(ctx, req, jobID, fatherVersionName) if err != nil { @@ -920,19 +923,6 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ ctx.HTML(http.StatusOK, tplModelArtsTrainJobShow) } -func getFlavorNameByEngineID(engineID int) (FlavorName string, err error) { - if engineID == 121 { - FlavorName = "TensorFlow-1.15" - return FlavorName, nil - } else if engineID == 122 { - FlavorName = "Mindspore-1.3" - return FlavorName, nil - } else { - log.Error("getFlavorName failed:%v", errors.New("getFlavorName failed")) - return "getFlavorName failed:", errors.New("getFlavorName failed") - } -} - // readDir reads the directory named by dirname and returns // a list of directory entries sorted by filename. func readDir(dirname string) ([]os.FileInfo, error) {