Browse Source

下载模型分离接口

tags/v1.21.12.1
liuzx 4 years ago
parent
commit
12c5a3cf1b
3 changed files with 121 additions and 18 deletions
  1. +59
    -0
      modules/storage/obs.go
  2. +2
    -0
      routers/api/v1/api.go
  3. +60
    -18
      routers/api/v1/repo/modelarts.go

+ 59
- 0
modules/storage/obs.go View File

@@ -235,6 +235,65 @@ func GetObsListObject(jobName, parentDir string) ([]FileInfo, error) {
}
}

func GetObsListObjectVersion(jobName, parentDir string, VersionOutputPath string) ([]FileInfo, error) {
input := &obs.ListObjectsInput{}
input.Bucket = setting.Bucket
input.Prefix = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, VersionOutputPath, parentDir), "/")
strPrefix := strings.Split(input.Prefix, "/")
output, err := ObsCli.ListObjects(input)
fileInfos := make([]FileInfo, 0)
if err == nil {
for _, val := range output.Contents {
str1 := strings.Split(val.Key, "/")
var isDir bool
var fileName, nextParentDir string
if strings.HasSuffix(val.Key, "/") {
//dirs in next level dir
if len(str1)-len(strPrefix) > 2 {
continue
}
fileName = str1[len(str1)-2]
isDir = true
if parentDir == "" {
nextParentDir = fileName
} else {
nextParentDir = parentDir + "/" + fileName
}

if fileName == strPrefix[len(strPrefix)-1] || (fileName+"/") == setting.OutPutPath {
continue
}
} else {
//files in next level dir
if len(str1)-len(strPrefix) > 1 {
continue
}
fileName = str1[len(str1)-1]
isDir = false
nextParentDir = parentDir
}

fileInfo := FileInfo{
ModTime: val.LastModified.Local().Format("2006-01-02 15:04:05"),
FileName: fileName,
Size: val.Size,
IsDir: isDir,
ParenDir: nextParentDir,
}
fileInfos = append(fileInfos, fileInfo)
}
sort.Slice(fileInfos, func(i, j int) bool {
return fileInfos[i].ModTime > fileInfos[j].ModTime
})
return fileInfos, err
} else {
if obsError, ok := err.(obs.ObsError); ok {
log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message)
}
return nil, err
}
}

func GetVersionObsListObject(jobName, parentDir string) ([]FileInfo, error) {
input := &obs.ListObjectsInput{}
input.Bucket = setting.Bucket


+ 2
- 0
routers/api/v1/api.go View File

@@ -880,6 +880,8 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("/log", repo.TrainJobGetLog)
m.Post("/del_version", repo.DelTrainJobVersion)
m.Post("/stop_version", repo.StopTrainJobVersion)
m.Get("/model_list", repo.ModelList)
m.Get("/model_download", repo.ModelDownload)
// m.Group("/:version-name", func() {
// m.Get("", repo.GetModelArtsTrainJobVersion)
// })


+ 60
- 18
routers/api/v1/repo/modelarts.go View File

@@ -8,12 +8,14 @@ package repo
import (
"net/http"
"strconv"
"strings"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/modelarts"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage"
)

func GetModelArtsNotebook(ctx *context.APIContext) {
@@ -157,15 +159,6 @@ func TrainJobGetLog(ctx *context.APIContext) {
return
}

// task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
// if err != nil {
// log.Error("GetCloudbrainByJobIDAndVersionName(%s) failed:%v", jobID, err.Error())
// ctx.JSON(http.StatusInternalServerError, map[string]interface{}{
// "err_msg": "GetCloudbrainByJobIDAndVersionName failed",
// })
// return
// }

resultLogFile, result, err := trainJobGetLogContent(jobID, versionName, baseLine, order, lines_int)
if err != nil {
log.Error("trainJobGetLog(%s) failed:%v", jobID, err.Error())
@@ -175,15 +168,6 @@ func TrainJobGetLog(ctx *context.APIContext) {

ctx.Data["log_file_name"] = resultLogFile.LogFileList[0]

// result, err := modelarts.GetTrainJobLog(jobID, strconv.FormatInt(task.VersionID, 10), baseLine, logFileName, order, modelarts.Lines)
// if err != nil {
// log.Error("GetTrainJobLog(%s) failed:%v", jobID, err.Error())
// ctx.JSON(http.StatusInternalServerError, map[string]interface{}{
// "err_msg": "GetTrainJobLog failed",
// })
// return
// }

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"LogFileName": resultLogFile.LogFileList[0],
@@ -316,3 +300,61 @@ func StopTrainJobVersion(ctx *context.APIContext) {
"StatusOK": 0,
})
}

func ModelList(ctx *context.APIContext) {
var (
err error
)

var jobID = ctx.Params(":jobid")
var versionName = ctx.Query("version_name")
parentDir := ctx.Query("parentDir")
dirArray := strings.Split(parentDir, "/")
task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
if err != nil {
log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error())
return
}
VersionOutputPath := "V" + strconv.Itoa(task.TotalVersionCount)
models, err := storage.GetObsListObjectVersion(task.JobName, parentDir, VersionOutputPath)
if err != nil {
log.Info("get TrainJobListModel failed:", err)
ctx.ServerError("GetObsListObject:", err)
return
}

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"VersionName": versionName,
"StatusOK": 0,
"Path": dirArray,
"Dirs": models,
"task": task,
"PageIsCloudBrain": true,
})
}

func ModelDownload(ctx *context.APIContext) {
var (
err error
)

// var jobID = ctx.Params(":jobid")
// var versionName = ctx.Query("version_name")
// parentDir := ctx.Query("parentDir")
// task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
// if err != nil {
// log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error())
// return
// }
parentDir := ctx.Query("parentDir")
fileName := ctx.Query("fileName")
jobName := ctx.Query("jobName")
url, err := storage.GetObsCreateSignedUrl(jobName, parentDir, fileName)
if err != nil {
log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
ctx.ServerError("GetObsCreateSignedUrl", err)
return
}
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}

Loading…
Cancel
Save