Browse Source

修改

tags/v1.21.12.1
liuzx 4 years ago
parent
commit
2dff30a6c9
4 changed files with 29 additions and 24 deletions
  1. +1
    -0
      models/cloudbrain.go
  2. +2
    -1
      modules/auth/modelarts.go
  3. +13
    -0
      modules/modelarts/modelarts.go
  4. +13
    -23
      routers/repo/modelarts.go

+ 1
- 0
models/cloudbrain.go View File

@@ -91,6 +91,7 @@ type Cloudbrain struct {
Description string Description string
WorkServerNumber int WorkServerNumber int
FlavorName string FlavorName string
EngineName string


User *User `xorm:"-"` User *User `xorm:"-"`
Repo *Repository `xorm:"-"` Repo *Repository `xorm:"-"`


+ 2
- 1
modules/auth/modelarts.go View File

@@ -19,7 +19,7 @@ type CreateModelArtsNotebookForm struct {
JobName string `form:"job_name" binding:"Required"` JobName string `form:"job_name" binding:"Required"`
Attachment string `form:"attachment"` Attachment string `form:"attachment"`
Description string `form:"description"` Description string `form:"description"`
Flavor string `form:"flavor"`
Flavor string `form:"flavor"`
} }


func (f *CreateModelArtsNotebookForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { func (f *CreateModelArtsNotebookForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors {
@@ -42,6 +42,7 @@ type CreateModelArtsTrainJobForm struct {
BranchName string `form:"branch_name" binding:"Required"` BranchName string `form:"branch_name" binding:"Required"`
VersionName string `form:"version_name" binding:"Required"` VersionName string `form:"version_name" binding:"Required"`
FlavorName string `form:"flavor_name" binding:"Required"` FlavorName string `form:"flavor_name" binding:"Required"`
EngineName string `form:"engine_name" binding:"Required"`
} }


func (f *CreateModelArtsTrainJobForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { func (f *CreateModelArtsTrainJobForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors {


+ 13
- 0
modules/modelarts/modelarts.go View File

@@ -82,6 +82,7 @@ type GenerateTrainJobReq struct {
FatherVersionName string FatherVersionName string
FlavorName string FlavorName string
VersionCount int VersionCount int
EngineName string
} }


type GenerateTrainJobVersionReq struct { type GenerateTrainJobVersionReq struct {
@@ -104,6 +105,7 @@ type GenerateTrainJobVersionReq struct {
CommitID string CommitID string
BranchName string BranchName string
FlavorName string FlavorName string
EngineName string
} }


type VersionInfo struct { type VersionInfo struct {
@@ -134,6 +136,15 @@ type ResourcePool struct {
} `json:"resource_pool"` } `json:"resource_pool"`
} }


type Parameter struct {
Label string `json:"label"`
Value string `json:"value"`
}

type Parameters struct {
Parameter []Parameter `json:"parameter"`
}

func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error { func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error {
var dataActualPath string var dataActualPath string
if uuid != "" { if uuid != "" {
@@ -263,6 +274,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
Description: req.Description, Description: req.Description,
WorkServerNumber: req.WorkServerNumber, WorkServerNumber: req.WorkServerNumber,
FlavorName: req.FlavorName, FlavorName: req.FlavorName,
EngineName: req.EngineName,
VersionCount: req.VersionCount, VersionCount: req.VersionCount,
}) })


@@ -331,6 +343,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobVersionR
Description: req.Description, Description: req.Description,
WorkServerNumber: req.WorkServerNumber, WorkServerNumber: req.WorkServerNumber,
FlavorName: req.FlavorName, FlavorName: req.FlavorName,
EngineName: req.EngineName,
}) })
if err != nil { if err != nil {
log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error()) log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())


+ 13
- 23
routers/repo/modelarts.go View File

@@ -441,6 +441,13 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error {
} }
ctx.Data["flavor_infos"] = flavorInfos.Info ctx.Data["flavor_infos"] = flavorInfos.Info


var Parameters modelarts.Parameters
if err = json.Unmarshal([]byte(task.Parameters), &Parameters); err != nil {
ctx.ServerError("json.Unmarshal failed:", err)
return err
}
ctx.Data["params"] = Parameters.Parameter

outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath
ctx.Data["train_url"] = outputObsPath ctx.Data["train_url"] = outputObsPath


@@ -454,9 +461,9 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error {
ctx.Data["description"] = task.Description ctx.Data["description"] = task.Description
ctx.Data["boot_file"] = task.BootFile ctx.Data["boot_file"] = task.BootFile
ctx.Data["dataset_name"] = task.DatasetName ctx.Data["dataset_name"] = task.DatasetName
ctx.Data["params"] = task.Parameters
ctx.Data["work_server_number"] = task.WorkServerNumber ctx.Data["work_server_number"] = task.WorkServerNumber
ctx.Data["flavor_name"] = task.FlavorName ctx.Data["flavor_name"] = task.FlavorName
ctx.Data["engine_name"] = task.FlavorName
ctx.Data["uuid"] = task.Uuid ctx.Data["uuid"] = task.Uuid
ctx.Data["flavor_code"] = task.FlavorCode ctx.Data["flavor_code"] = task.FlavorCode
ctx.Data["engine_id"] = task.EngineID ctx.Data["engine_id"] = task.EngineID
@@ -493,6 +500,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
isLatestVersion := modelarts.IsLatestVersion isLatestVersion := modelarts.IsLatestVersion
FlavorName := form.FlavorName FlavorName := form.FlavorName
VersionCount := modelarts.VersionCount VersionCount := modelarts.VersionCount
EngineName := form.EngineName


if err := paramCheckCreateTrainJob(form); err != nil { if err := paramCheckCreateTrainJob(form); err != nil {
log.Error("paramCheckCreateTrainJob failed:(%v)", err) log.Error("paramCheckCreateTrainJob failed:(%v)", err)
@@ -619,13 +627,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
return return
} }
} }
//将引擎id转化为引擎名称
FlavorName, err = getFlavorNameByEngineID(engineID)
if err != nil {
log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error())
ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobNew, &form)
return
}


req := &modelarts.GenerateTrainJobReq{ req := &modelarts.GenerateTrainJobReq{
JobName: jobName, JobName: jobName,
@@ -648,6 +649,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
Params: form.Params, Params: form.Params,
FatherVersionName: modelarts.InitFatherVersionName, FatherVersionName: modelarts.InitFatherVersionName,
FlavorName: FlavorName, FlavorName: FlavorName,
EngineName: EngineName,
VersionCount: VersionCount, VersionCount: VersionCount,
} }


@@ -719,6 +721,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
branch_name := form.BranchName branch_name := form.BranchName
fatherVersionName := versionName fatherVersionName := versionName
FlavorName := form.FlavorName FlavorName := form.FlavorName
EngineName := form.EngineName


if err := paramCheckCreateTrainJob(form); err != nil { if err := paramCheckCreateTrainJob(form); err != nil {
log.Error("paramCheckCreateTrainJob failed:(%v)", err) log.Error("paramCheckCreateTrainJob failed:(%v)", err)
@@ -843,8 +846,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
return return
} }
} }
//将引擎id转化为引擎名称
FlavorName, err = getFlavorNameByEngineID(engineID)

if err != nil { if err != nil {
log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error()) log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error())
ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form) ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form)
@@ -876,6 +878,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
CommitID: commitID, CommitID: commitID,
BranchName: branch_name, BranchName: branch_name,
FlavorName: FlavorName, FlavorName: FlavorName,
EngineName: EngineName,
} }
err = modelarts.GenerateTrainJobVersion(ctx, req, jobID, fatherVersionName) err = modelarts.GenerateTrainJobVersion(ctx, req, jobID, fatherVersionName)
if err != nil { if err != nil {
@@ -920,19 +923,6 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
ctx.HTML(http.StatusOK, tplModelArtsTrainJobShow) ctx.HTML(http.StatusOK, tplModelArtsTrainJobShow)
} }


func getFlavorNameByEngineID(engineID int) (FlavorName string, err error) {
if engineID == 121 {
FlavorName = "TensorFlow-1.15"
return FlavorName, nil
} else if engineID == 122 {
FlavorName = "Mindspore-1.3"
return FlavorName, nil
} else {
log.Error("getFlavorName failed:%v", errors.New("getFlavorName failed"))
return "getFlavorName failed:", errors.New("getFlavorName failed")
}
}

// readDir reads the directory named by dirname and returns // readDir reads the directory named by dirname and returns
// a list of directory entries sorted by filename. // a list of directory entries sorted by filename.
func readDir(dirname string) ([]os.FileInfo, error) { func readDir(dirname string) ([]os.FileInfo, error) {


Loading…
Cancel
Save