Browse Source

fix bug

tags/v1.21.12.1
liuzx 4 years ago
parent
commit
63ffb73eca
3 changed files with 128 additions and 7 deletions
  1. +45
    -2
      modules/storage/obs.go
  2. +63
    -4
      routers/api/v1/repo/modelarts.go
  3. +20
    -1
      routers/repo/modelarts.go

+ 45
- 2
modules/storage/obs.go View File

@@ -176,10 +176,10 @@ func ObsModelDownload(JobName string, fileName string) (io.ReadCloser, error) {
}
}

func GetObsListObject(jobName, parentDir string) ([]FileInfo, error) {
func GetObsListObject(jobName, parentDir, versionName string) ([]FileInfo, error) {
input := &obs.ListObjectsInput{}
input.Bucket = setting.Bucket
input.Prefix = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
input.Prefix = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, versionName, parentDir), "/")
strPrefix := strings.Split(input.Prefix, "/")
output, err := ObsCli.ListObjects(input)
fileInfos := make([]FileInfo, 0)
@@ -275,8 +275,34 @@ func GetObsCreateSignedUrl(jobName, parentDir, fileName string) (string, error)
log.Error("CreateSignedUrl failed:", err.Error())
return "", err
}
log.Info("SignedUrl:%s", output.SignedUrl)
return output.SignedUrl, nil
}

func GetObsCreateSignedUrlByBucketAndKey(bucket, key string) (string, error) {
input := &obs.CreateSignedUrlInput{}
input.Bucket = bucket
input.Key = key

input.Expires = 60 * 60
input.Method = obs.HttpMethodGet
comma := strings.LastIndex(key, "/")
filename := key
if comma != -1 {
filename = key[comma+1:]
}
reqParams := make(map[string]string)
filename = url.QueryEscape(filename)
reqParams["response-content-disposition"] = "attachment; filename=\"" + filename + "\""
input.QueryParams = reqParams
output, err := ObsCli.CreateSignedUrl(input)
if err != nil {
log.Error("CreateSignedUrl failed:", err.Error())
return "", err
}

return output.SignedUrl, nil

}

func ObsGetPreSignedUrl(uuid, fileName string) (string, error) {
@@ -311,3 +337,20 @@ func ObsCreateObject(path string) error {

return nil
}

func ObsDownloadAFile(bucket string, key string) (io.ReadCloser, error) {
input := &obs.GetObjectInput{}
input.Bucket = bucket
input.Key = key
output, err := ObsCli.GetObject(input)
if err == nil {
log.Info("StorageClass:%s, ETag:%s, ContentType:%s, ContentLength:%d, LastModified:%s\n",
output.StorageClass, output.ETag, output.ContentType, output.ContentLength, output.LastModified)
return output.Body, nil
} else if obsError, ok := err.(obs.ObsError); ok {
log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message)
return nil, obsError
} else {
return nil, err
}
}

+ 63
- 4
routers/api/v1/repo/modelarts.go View File

@@ -7,6 +7,7 @@ package repo

import (
"net/http"
"path"
"strconv"
"strings"

@@ -14,6 +15,7 @@ import (
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/modelarts"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage"
)

@@ -302,9 +304,7 @@ func ModelList(ctx *context.APIContext) {
log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error())
return
}
VersionOutputPath := modelarts.GetVersionOutputPathByTotalVersionCount(task.TotalVersionCount)
parentDir = VersionOutputPath + "/" + parentDir
models, err := storage.GetObsListObject(task.JobName, parentDir)
models, err := storage.GetObsListObject(task.JobName, parentDir, versionName)
if err != nil {
log.Info("get TrainJobListModel failed:", err)
ctx.ServerError("GetObsListObject:", err)
@@ -322,7 +322,7 @@ func ModelList(ctx *context.APIContext) {
})
}

func ModelDownload(ctx *context.APIContext) {
func ModelDownload1(ctx *context.APIContext) {
var (
err error
)
@@ -346,3 +346,62 @@ func ModelDownload(ctx *context.APIContext) {
}
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}

func ModelDownload(ctx *context.Context) {
var (
err error
)

var jobID = ctx.Params(":jobid")
versionName := ctx.Query("version_name")
// versionName := "V0001"
parentDir := ctx.Query("parent_dir")
fileName := ctx.Query("file_name")
log.Info("DownloadSingleModelFile start.")
// id := ctx.Params(":ID")
// path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName
task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
if err != nil {
log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error())
return
}

path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/")
log.Info("Download path is:%s", path)
if setting.PROXYURL != "" {
body, err := storage.ObsDownloadAFile(setting.Bucket, path)
if err != nil {
log.Info("download error.")
} else {
//count++
// models.ModifyModelDownloadCount(id)
defer body.Close()
ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName)
ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
p := make([]byte, 1024)
var readErr error
var readCount int
// 读取对象内容
for {
readCount, readErr = body.Read(p)
if readCount > 0 {
ctx.Resp.Write(p[:readCount])
//fmt.Printf("%s", p[:readCount])
}
if readErr != nil {
break
}
}
}
} else {
url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path)
if err != nil {
log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
ctx.ServerError("GetObsCreateSignedUrl", err)
return
}
//count++
// models.ModifyModelDownloadCount(id)
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}
}

+ 20
- 1
routers/repo/modelarts.go View File

@@ -480,6 +480,12 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error {
var jobID = ctx.Params(":jobid")
// var versionName = ctx.Params(":version-name")
var versionName = ctx.Query("version_name")
// canNewJob, err := canUserCreateTrainJobVersion(ctx, jobID)
// if err != nil {
// ctx.ServerError("get can info failed", err)
// return err
// }
// ctx.Data["canNewJob"] = canNewJob

task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
if err != nil {
@@ -541,7 +547,8 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error {
ctx.ServerError("GetBranches error:", err)
return err
}
ctx.Data["branches"] = Branches

ctx.Data["branch"] = Branches
ctx.Data["branch_name"] = task.BranchName
ctx.Data["description"] = task.Description
ctx.Data["boot_file"] = task.BootFile
@@ -1305,6 +1312,18 @@ func canUserCreateTrainJob(uid int64) (bool, error) {

return org.IsOrgMember(uid)
}
func canUserCreateTrainJobVersion(ctx *context.Context, jobID string) (bool, error) {

var versionName = "V0001"
task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
if err != nil {
return false, err
}
if ctx.User.ID == task.User.ID {
return true, nil
}
return false, err
}

func TrainJobGetConfigList(ctx *context.Context) {
ctx.Data["PageIsTrainJob"] = true


Loading…
Cancel
Save