Browse Source

update

tags/v1.22.1.2
liuzx 3 years ago
parent
commit
2ba72dcb25
5 changed files with 88 additions and 0 deletions
  1. +1
    -0
      models/cloudbrain.go
  2. +1
    -0
      modules/auth/modelarts.go
  3. +2
    -0
      modules/modelarts/modelarts.go
  4. +83
    -0
      routers/repo/modelarts.go
  5. +1
    -0
      routers/routes/routes.go

+ 1
- 0
models/cloudbrain.go View File

@@ -126,6 +126,7 @@ type Cloudbrain struct {
EngineName string //引擎名称
TotalVersionCount int //任务的所有版本数量,包括删除的

LabelName string //标签名称
ModelName string //模型名称
ModelVersion string //模型版本
CkptName string //权重文件名称


+ 1
- 0
modules/auth/modelarts.go View File

@@ -62,6 +62,7 @@ type CreateModelArtsInferenceJobForm struct {
VersionName string `form:"version_name" binding:"Required"`
FlavorName string `form:"flaver_names" binding:"Required"`
EngineName string `form:"engine_names" binding:"Required"`
LabelName string `form:"label_names" binding:"Required"`
TrainUrl string `form:"train_url" binding:"Required"`
ModelName string `form:"model_name" binding:"Required"`
ModelVersion string `form:"model_version" binding:"Required"`


+ 2
- 0
modules/modelarts/modelarts.go View File

@@ -136,6 +136,7 @@ type GenerateInferenceJobReq struct {
BranchName string
FlavorName string
EngineName string
LabelName string
IsLatestVersion string
VersionCount int
TotalVersionCount int
@@ -535,6 +536,7 @@ func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (e
WorkServerNumber: req.WorkServerNumber,
FlavorName: req.FlavorName,
EngineName: req.EngineName,
LabelName: req.LabelName,
IsLatestVersion: req.IsLatestVersion,
VersionCount: req.VersionCount,
TotalVersionCount: req.TotalVersionCount,


+ 83
- 0
routers/repo/modelarts.go View File

@@ -1,6 +1,7 @@
package repo

import (
"archive/zip"
"encoding/json"
"errors"
"io"
@@ -1284,6 +1285,19 @@ func paramCheckCreateInferenceJob(form auth.CreateModelArtsInferenceJobForm) err
return errors.New("计算节点数必须在1-25之间")
}

if form.ModelName == "" {
log.Error("the ModelName(%d) must not be nil", form.ModelName)
return errors.New("模型名称不能为空")
}
if form.ModelVersion == "" {
log.Error("the ModelVersion(%d) must not be nil", form.ModelVersion)
return errors.New("模型版本不能为空")
}
if form.CkptName == "" {
log.Error("the CkptName(%d) must not be nil", form.CkptName)
return errors.New("权重文件不能为空")
}

return nil
}

@@ -1564,6 +1578,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
branch_name := form.BranchName
FlavorName := form.FlavorName
EngineName := form.EngineName
LabelName := form.LabelName
isLatestVersion := modelarts.IsLatestVersion
VersionCount := modelarts.VersionCount
trainUrl := form.TrainUrl
@@ -1694,6 +1709,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
Params: form.Params,
FlavorName: FlavorName,
EngineName: EngineName,
LabelName: LabelName,
IsLatestVersion: isLatestVersion,
VersionCount: VersionCount,
TotalVersionCount: modelarts.TotalVersionCount,
@@ -2018,3 +2034,70 @@ func DeleteJobStorage(jobName string) error {

return nil
}

func DownloadMultiResultFile(ctx *context.Context) {
log.Info("DownloadMultiModelFile start.")
id := ctx.Query("ID")
log.Info("id=" + id)
task, err := models.QueryModelById(id)
if err != nil {
log.Error("no such model!", err.Error())
ctx.ServerError("no such model:", err)
return
}
if !isCanDeleteOrDownload(ctx, task) {
ctx.ServerError("no right.", errors.New(ctx.Tr("repo.model_noright")))
return
}

path := Model_prefix + models.AttachmentRelativePath(id) + "/"

allFile, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, path)
if err == nil {
//count++
models.ModifyModelDownloadCount(id)

returnFileName := task.Name + "_" + task.Version + ".zip"
ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+returnFileName)
ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
w := zip.NewWriter(ctx.Resp)
defer w.Close()
for _, oneFile := range allFile {
if oneFile.IsDir {
log.Info("zip dir name:" + oneFile.FileName)
} else {
log.Info("zip file name:" + oneFile.FileName)
fDest, err := w.Create(oneFile.FileName)
if err != nil {
log.Info("create zip entry error, download file failed: %s\n", err.Error())
ctx.ServerError("download file failed:", err)
return
}
body, err := storage.ObsDownloadAFile(setting.Bucket, path+oneFile.FileName)
if err != nil {
log.Info("download file failed: %s\n", err.Error())
ctx.ServerError("download file failed:", err)
return
} else {
defer body.Close()
p := make([]byte, 1024)
var readErr error
var readCount int
// 读取对象内容
for {
readCount, readErr = body.Read(p)
if readCount > 0 {
fDest.Write(p[:readCount])
}
if readErr != nil {
break
}
}
}
}
}
} else {
log.Info("error,msg=" + err.Error())
ctx.ServerError("no file to download.", err)
}
}

+ 1
- 0
routers/routes/routes.go View File

@@ -1038,6 +1038,7 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("", reqRepoCloudBrainReader, repo.InferenceJobShow)
m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobStop)
m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobDel)
m.Get("/downloadall", repo.DownloadMultiResultFile)
})
m.Get("/create", reqRepoCloudBrainWriter, repo.InferenceJobNew)
m.Post("/create", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsInferenceJobForm{}), repo.InferenceJobCreate)


Loading…
Cancel
Save