Browse Source

update

tags/v1.22.1.2
liuzx 3 years ago
parent
commit
87103a8597
4 changed files with 87 additions and 7 deletions
  1. +23
    -1
      models/cloudbrain.go
  2. +4
    -4
      modules/modelarts/modelarts.go
  3. +56
    -0
      modules/modelarts/resty.go
  4. +4
    -2
      routers/repo/modelarts.go

+ 23
- 1
models/cloudbrain.go View File

@@ -216,7 +216,7 @@ type CloudbrainsOptions struct {
JobType string JobType string
VersionName string VersionName string
IsLatestVersion string IsLatestVersion string
JobTypeNot bool
JobTypeNot bool
} }


type TaskPod struct { type TaskPod struct {
@@ -650,6 +650,28 @@ type Config struct {
Flavor Flavor `json:"flavor"` Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"` PoolID string `json:"pool_id"`
} }
type CreateInferenceJobParams struct {
JobName string `json:"job_name"`
Description string `json:"job_desc"`
InfConfig InfConfig `json:"config"`
WorkspaceID string `json:"workspace_id"`
}

type InfConfig struct {
WorkServerNum int `json:"worker_server_num"`
AppUrl string `json:"app_url"` //训练作业的代码目录
BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下
Parameter []Parameter `json:"parameter"`
DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL
EngineID int64 `json:"engine_id"`
// TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL
LogUrl string `json:"log_url"`
//UserImageUrl string `json:"user_image_url"`
//UserCommand string `json:"user_command"`
CreateVersion bool `json:"create_version"`
Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"`
}


type CreateTrainJobVersionParams struct { type CreateTrainJobVersionParams struct {
Description string `json:"job_desc"` Description string `json:"job_desc"`


+ 4
- 4
modules/modelarts/modelarts.go View File

@@ -47,7 +47,7 @@ const (
TrainUrl = "train_url" TrainUrl = "train_url"
DataUrl = "data_url" DataUrl = "data_url"
ResultUrl = "result_url" ResultUrl = "result_url"
CkptName = "ckpt_name"
CkptUrl = "ckpt_url"
PerPage = 10 PerPage = 10
IsLatestVersion = "1" IsLatestVersion = "1"
NotLatestVersion = "0" NotLatestVersion = "0"
@@ -477,16 +477,16 @@ func GetVersionOutputPathByTotalVersionCount(TotalVersionCount int) (VersionOutp
} }


func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) { func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) {
jobResult, err := createTrainJob(models.CreateTrainJobParams{
jobResult, err := createInferenceJob(models.CreateInferenceJobParams{
JobName: req.JobName, JobName: req.JobName,
Description: req.Description, Description: req.Description,
Config: models.Config{
InfConfig: models.InfConfig{
WorkServerNum: req.WorkServerNumber, WorkServerNum: req.WorkServerNumber,
AppUrl: req.CodeObsPath, AppUrl: req.CodeObsPath,
BootFileUrl: req.BootFileUrl, BootFileUrl: req.BootFileUrl,
DataUrl: req.DataUrl, DataUrl: req.DataUrl,
EngineID: req.EngineID, EngineID: req.EngineID,
TrainUrl: req.TrainUrl,
// TrainUrl: req.TrainUrl,
LogUrl: req.LogUrl, LogUrl: req.LogUrl,
PoolID: req.PoolID, PoolID: req.PoolID,
CreateVersion: true, CreateVersion: true,


+ 56
- 0
modules/modelarts/resty.go View File

@@ -874,3 +874,59 @@ sendjob:


return &result, nil return &result, nil
} }

func createInferenceJob(createJobParams models.CreateInferenceJobParams) (*models.CreateTrainJobResult, error) {
checkSetting()
client := getRestyClient()
var result models.CreateTrainJobResult

retry := 0

sendjob:
res, err := client.R().
SetHeader("Content-Type", "application/json").
SetAuthToken(TOKEN).
SetBody(createJobParams).
SetResult(&result).
Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob)

if err != nil {
return nil, fmt.Errorf("resty create train-job: %s", err)
}

req, _ := json.Marshal(createJobParams)
log.Info("%s", req)

if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
retry++
_ = getToken()
goto sendjob
}

if res.StatusCode() != http.StatusOK {
var temp models.ErrorResult
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
}
log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
BootFileErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.BootFileUrl + "'."
DataSetErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.DataUrl + "'."
if temp.ErrorMsg == BootFileErrorMsg {
log.Error("启动文件错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("启动文件错误!")
}
if temp.ErrorMsg == DataSetErrorMsg {
log.Error("数据集错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("数据集错误!")
}
return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
}

if !result.IsSuccess {
log.Error("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
}

return &result, nil
}

+ 4
- 2
routers/repo/modelarts.go View File

@@ -1596,6 +1596,8 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
modelVersion := form.ModelVersion modelVersion := form.ModelVersion
ckptName := form.CkptName ckptName := form.CkptName


ckptUrl := form.TrainUrl + form.CkptName

count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID)
if err != nil { if err != nil {
log.Error("GetCloudbrainTrainJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) log.Error("GetCloudbrainTrainJobCountByUserID failed:%v", err, ctx.Data["MsgID"])
@@ -1675,8 +1677,8 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
Label: modelarts.ResultUrl, Label: modelarts.ResultUrl,
Value: "s3:/" + resultObsPath, Value: "s3:/" + resultObsPath,
}, models.Parameter{ }, models.Parameter{
Label: modelarts.CkptName,
Value: ckptName,
Label: modelarts.CkptUrl,
Value: "s3:/" + ckptUrl,
}) })
if len(params) != 0 { if len(params) != 0 {
err := json.Unmarshal([]byte(params), &parameters) err := json.Unmarshal([]byte(params), &parameters)


Loading…
Cancel
Save