Browse Source

增加训练版本的模型下载接口

tags/v1.21.12.1
liuzx 4 years ago
parent
commit
800bd819bb
3 changed files with 99 additions and 10 deletions
  1. +64
    -8
      modules/storage/obs.go
  2. +34
    -2
      routers/repo/modelarts.go
  3. +1
    -0
      routers/routes/routes.go

+ 64
- 8
modules/storage/obs.go View File

@@ -185,10 +185,10 @@ func GetObsListObject(jobName, parentDir string) ([]FileInfo, error) {
for _, val := range output.Contents {
str1 := strings.Split(val.Key, "/")
var isDir bool
var fileName,nextParentDir string
var fileName, nextParentDir string
if strings.HasSuffix(val.Key, "/") {
//dirs in next level dir
if len(str1) - len(strPrefix) > 2 {
if len(str1)-len(strPrefix) > 2 {
continue
}
fileName = str1[len(str1)-2]
@@ -199,12 +199,12 @@ func GetObsListObject(jobName, parentDir string) ([]FileInfo, error) {
nextParentDir = parentDir + "/" + fileName
}

if fileName == strPrefix[len(strPrefix)-1] || (fileName + "/") == setting.OutPutPath {
if fileName == strPrefix[len(strPrefix)-1] || (fileName+"/") == setting.OutPutPath {
continue
}
} else {
//files in next level dir
if len(str1) - len(strPrefix) > 1 {
if len(str1)-len(strPrefix) > 1 {
continue
}
fileName = str1[len(str1)-1]
@@ -213,10 +213,66 @@ func GetObsListObject(jobName, parentDir string) ([]FileInfo, error) {
}

fileInfo := FileInfo{
ModTime: val.LastModified.Format("2006-01-02 15:04:05"),
ModTime: val.LastModified.Format("2006-01-02 15:04:05"),
FileName: fileName,
Size: val.Size,
IsDir:isDir,
Size: val.Size,
IsDir: isDir,
ParenDir: nextParentDir,
}
fileInfos = append(fileInfos, fileInfo)
}
return fileInfos, err
} else {
if obsError, ok := err.(obs.ObsError); ok {
log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message)
}
return nil, err
}
}

func GetVersionObsListObject(jobName, parentDir string) ([]FileInfo, error) {
input := &obs.ListObjectsInput{}
input.Bucket = setting.Bucket
input.Prefix = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
strPrefix := strings.Split(input.Prefix, "/")
output, err := ObsCli.ListObjects(input)
fileInfos := make([]FileInfo, 0)
if err == nil {
for _, val := range output.Contents {
str1 := strings.Split(val.Key, "/")
var isDir bool
var fileName, nextParentDir string
if strings.HasSuffix(val.Key, "/") {
//dirs in next level dir
if len(str1)-len(strPrefix) > 2 {
continue
}
fileName = str1[len(str1)-2]
isDir = true
if parentDir == "" {
nextParentDir = fileName
} else {
nextParentDir = parentDir + "/" + fileName
}

if fileName == strPrefix[len(strPrefix)-1] || (fileName+"/") == setting.OutPutPath {
continue
}
} else {
//files in next level dir
if len(str1)-len(strPrefix) > 1 {
continue
}
fileName = str1[len(str1)-1]
isDir = false
nextParentDir = parentDir
}

fileInfo := FileInfo{
ModTime: val.LastModified.Format("2006-01-02 15:04:05"),
FileName: fileName,
Size: val.Size,
IsDir: isDir,
ParenDir: nextParentDir,
}
fileInfos = append(fileInfos, fileInfo)
@@ -257,7 +313,7 @@ func GetObsCreateSignedUrl(jobName, parentDir, fileName string) (string, error)
input := &obs.CreateSignedUrlInput{}
input.Bucket = setting.Bucket
input.Key = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir, fileName), "/")
input.Expires = 60 * 60
input.Method = obs.HttpMethodGet



+ 34
- 2
routers/repo/modelarts.go View File

@@ -1219,6 +1219,10 @@ func TrainJobShow(ctx *context.Context) {

var jobID = ctx.Params(":jobid")
task, err := models.GetCloudbrainByJobID(jobID)
if err != nil {
ctx.ServerError("GetCloudbrainByJobID faild", err)
return
}

repo := ctx.Repo.Repository
page := ctx.QueryInt("page")
@@ -1290,8 +1294,8 @@ func TrainJobShow(ctx *context.Context) {
ctx.Data["task"] = task
ctx.Data["jobID"] = jobID
ctx.Data["result"] = result
ctx.Data["VersionListTasks"] = VersionListTasks
ctx.Data["VersionLisCount"] = VersionListCount
ctx.Data["version_list_task"] = VersionListTasks
ctx.Data["version_list_count"] = VersionListCount
ctx.HTML(http.StatusOK, tplModelArtsTrainJobShow)
}

@@ -1541,6 +1545,34 @@ func TrainJobShowModels(ctx *context.Context) {
ctx.HTML(200, tplModelArtsTrainJobShowModels)
}

func TrainJobVersionShowModels(ctx *context.Context) {
ctx.Data["PageIsCloudBrain"] = true

jobID := ctx.Params(":jobid")
parentDir := ctx.Query("parentDir")
versionName := ctx.Query("version_name")
dirArray := strings.Split(parentDir, "/")
task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
if err != nil {
log.Error("no such job!", ctx.Data["msgID"])
ctx.ServerError("no such job:", err)
return
}
parentDir = versionName
models, err := storage.GetVersionObsListObject(task.JobName, parentDir)
if err != nil {
log.Info("get TrainJobListModel failed:", err)
ctx.ServerError("GetVersionObsListObject:", err)
return
}

ctx.Data["Path"] = dirArray
ctx.Data["Dirs"] = models
ctx.Data["task"] = task
ctx.Data["JobID"] = jobID
ctx.HTML(200, tplModelArtsTrainJobShowModels)
}

func TrainJobDownloadModel(ctx *context.Context) {
parentDir := ctx.Query("parentDir")
fileName := ctx.Query("fileName")


+ 1
- 0
routers/routes/routes.go View File

@@ -988,6 +988,7 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("/log", reqRepoCloudBrainReader, repo.TrainJobGetLog)
m.Get("/models", reqRepoCloudBrainReader, repo.TrainJobShowModels)
m.Get("/download_model", reqRepoCloudBrainReader, repo.TrainJobDownloadModel)
m.Get("/version_models", reqRepoCloudBrainReader, repo.TrainJobVersionShowModels)
m.Get("/create_version", reqRepoCloudBrainReader, repo.TrainJobNewVersion)
m.Post("/create_version", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreateVersion)
m.Post("/stop_version", reqRepoCloudBrainWriter, repo.TrainJobVersionStop)


Loading…
Cancel
Save