Browse Source

update

tags/v1.22.1.2
liuzx 3 years ago
parent
commit
2ba72dcb25
5 changed files with 88 additions and 0 deletions
  1. +1
    -0
      models/cloudbrain.go
  2. +1
    -0
      modules/auth/modelarts.go
  3. +2
    -0
      modules/modelarts/modelarts.go
  4. +83
    -0
      routers/repo/modelarts.go
  5. +1
    -0
      routers/routes/routes.go

+ 1
- 0
models/cloudbrain.go View File

@@ -126,6 +126,7 @@ type Cloudbrain struct {
EngineName string //引擎名称 EngineName string //引擎名称
TotalVersionCount int //任务的所有版本数量,包括删除的 TotalVersionCount int //任务的所有版本数量,包括删除的


LabelName string //标签名称
ModelName string //模型名称 ModelName string //模型名称
ModelVersion string //模型版本 ModelVersion string //模型版本
CkptName string //权重文件名称 CkptName string //权重文件名称


+ 1
- 0
modules/auth/modelarts.go View File

@@ -62,6 +62,7 @@ type CreateModelArtsInferenceJobForm struct {
VersionName string `form:"version_name" binding:"Required"` VersionName string `form:"version_name" binding:"Required"`
FlavorName string `form:"flaver_names" binding:"Required"` FlavorName string `form:"flaver_names" binding:"Required"`
EngineName string `form:"engine_names" binding:"Required"` EngineName string `form:"engine_names" binding:"Required"`
LabelName string `form:"label_names" binding:"Required"`
TrainUrl string `form:"train_url" binding:"Required"` TrainUrl string `form:"train_url" binding:"Required"`
ModelName string `form:"model_name" binding:"Required"` ModelName string `form:"model_name" binding:"Required"`
ModelVersion string `form:"model_version" binding:"Required"` ModelVersion string `form:"model_version" binding:"Required"`


+ 2
- 0
modules/modelarts/modelarts.go View File

@@ -136,6 +136,7 @@ type GenerateInferenceJobReq struct {
BranchName string BranchName string
FlavorName string FlavorName string
EngineName string EngineName string
LabelName string
IsLatestVersion string IsLatestVersion string
VersionCount int VersionCount int
TotalVersionCount int TotalVersionCount int
@@ -535,6 +536,7 @@ func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (e
WorkServerNumber: req.WorkServerNumber, WorkServerNumber: req.WorkServerNumber,
FlavorName: req.FlavorName, FlavorName: req.FlavorName,
EngineName: req.EngineName, EngineName: req.EngineName,
LabelName: req.LabelName,
IsLatestVersion: req.IsLatestVersion, IsLatestVersion: req.IsLatestVersion,
VersionCount: req.VersionCount, VersionCount: req.VersionCount,
TotalVersionCount: req.TotalVersionCount, TotalVersionCount: req.TotalVersionCount,


+ 83
- 0
routers/repo/modelarts.go View File

@@ -1,6 +1,7 @@
package repo package repo


import ( import (
"archive/zip"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
@@ -1284,6 +1285,19 @@ func paramCheckCreateInferenceJob(form auth.CreateModelArtsInferenceJobForm) err
return errors.New("计算节点数必须在1-25之间") return errors.New("计算节点数必须在1-25之间")
} }


if form.ModelName == "" {
log.Error("the ModelName(%d) must not be nil", form.ModelName)
return errors.New("模型名称不能为空")
}
if form.ModelVersion == "" {
log.Error("the ModelVersion(%d) must not be nil", form.ModelVersion)
return errors.New("模型版本不能为空")
}
if form.CkptName == "" {
log.Error("the CkptName(%d) must not be nil", form.CkptName)
return errors.New("权重文件不能为空")
}

return nil return nil
} }


@@ -1564,6 +1578,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
branch_name := form.BranchName branch_name := form.BranchName
FlavorName := form.FlavorName FlavorName := form.FlavorName
EngineName := form.EngineName EngineName := form.EngineName
LabelName := form.LabelName
isLatestVersion := modelarts.IsLatestVersion isLatestVersion := modelarts.IsLatestVersion
VersionCount := modelarts.VersionCount VersionCount := modelarts.VersionCount
trainUrl := form.TrainUrl trainUrl := form.TrainUrl
@@ -1694,6 +1709,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
Params: form.Params, Params: form.Params,
FlavorName: FlavorName, FlavorName: FlavorName,
EngineName: EngineName, EngineName: EngineName,
LabelName: LabelName,
IsLatestVersion: isLatestVersion, IsLatestVersion: isLatestVersion,
VersionCount: VersionCount, VersionCount: VersionCount,
TotalVersionCount: modelarts.TotalVersionCount, TotalVersionCount: modelarts.TotalVersionCount,
@@ -2018,3 +2034,70 @@ func DeleteJobStorage(jobName string) error {


return nil return nil
} }

func DownloadMultiResultFile(ctx *context.Context) {
log.Info("DownloadMultiModelFile start.")
id := ctx.Query("ID")
log.Info("id=" + id)
task, err := models.QueryModelById(id)
if err != nil {
log.Error("no such model!", err.Error())
ctx.ServerError("no such model:", err)
return
}
if !isCanDeleteOrDownload(ctx, task) {
ctx.ServerError("no right.", errors.New(ctx.Tr("repo.model_noright")))
return
}

path := Model_prefix + models.AttachmentRelativePath(id) + "/"

allFile, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, path)
if err == nil {
//count++
models.ModifyModelDownloadCount(id)

returnFileName := task.Name + "_" + task.Version + ".zip"
ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+returnFileName)
ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
w := zip.NewWriter(ctx.Resp)
defer w.Close()
for _, oneFile := range allFile {
if oneFile.IsDir {
log.Info("zip dir name:" + oneFile.FileName)
} else {
log.Info("zip file name:" + oneFile.FileName)
fDest, err := w.Create(oneFile.FileName)
if err != nil {
log.Info("create zip entry error, download file failed: %s\n", err.Error())
ctx.ServerError("download file failed:", err)
return
}
body, err := storage.ObsDownloadAFile(setting.Bucket, path+oneFile.FileName)
if err != nil {
log.Info("download file failed: %s\n", err.Error())
ctx.ServerError("download file failed:", err)
return
} else {
defer body.Close()
p := make([]byte, 1024)
var readErr error
var readCount int
// 读取对象内容
for {
readCount, readErr = body.Read(p)
if readCount > 0 {
fDest.Write(p[:readCount])
}
if readErr != nil {
break
}
}
}
}
}
} else {
log.Info("error,msg=" + err.Error())
ctx.ServerError("no file to download.", err)
}
}

+ 1
- 0
routers/routes/routes.go View File

@@ -1038,6 +1038,7 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("", reqRepoCloudBrainReader, repo.InferenceJobShow) m.Get("", reqRepoCloudBrainReader, repo.InferenceJobShow)
m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobStop) m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobStop)
m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobDel) m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobDel)
m.Get("/downloadall", repo.DownloadMultiResultFile)
}) })
m.Get("/create", reqRepoCloudBrainWriter, repo.InferenceJobNew) m.Get("/create", reqRepoCloudBrainWriter, repo.InferenceJobNew)
m.Post("/create", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsInferenceJobForm{}), repo.InferenceJobCreate) m.Post("/create", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsInferenceJobForm{}), repo.InferenceJobCreate)


Loading…
Cancel
Save