Browse Source

提交代码。

Signed-off-by: zouap <zouap@pcl.ac.cn>
tags/v1.22.7.1
zouap 3 years ago
parent
commit
c7d8d92f8b
4 changed files with 90 additions and 4 deletions
  1. +24
    -0
      models/cloudbrain.go
  2. +7
    -4
      modules/modelarts/modelarts.go
  3. +56
    -0
      modules/modelarts/resty.go
  4. +3
    -0
      routers/repo/ai_model_convert.go

+ 24
- 0
models/cloudbrain.go View File

@@ -893,6 +893,28 @@ type NotebookDelResult struct {
InstanceID string `json:"instance_id"`
}

// CreateUserImageTrainJobParams is the JSON request body for creating a train
// job whose configuration carries a user-supplied image (see UserImageConfig).
// The JSON tags define the wire names, e.g. Description is sent as "job_desc".
type CreateUserImageTrainJobParams struct {
// JobName is serialized as "job_name".
JobName string `json:"job_name"`
// Description is serialized as "job_desc".
Description string `json:"job_desc"`
// Config holds the job configuration, including the user image and command.
Config UserImageConfig `json:"config"`
// WorkspaceID is serialized as "workspace_id".
WorkspaceID string `json:"workspace_id"`
}

// UserImageConfig is the "config" section of CreateUserImageTrainJobParams.
// It describes a train job together with the user image (UserImageUrl) and
// the command (UserCommand) supplied for it.
type UserImageConfig struct {
// WorkServerNum is serialized as "worker_server_num".
// NOTE(review): presumably the number of worker servers for the job - confirm.
WorkServerNum int `json:"worker_server_num"`
AppUrl string `json:"app_url"` // code directory of the training job
BootFileUrl string `json:"boot_file_url"` // boot file of the training job's code; must be located under the code directory
// Parameter is serialized as "parameter".
Parameter []Parameter `json:"parameter"`
DataUrl string `json:"data_url"` // OBS path URL of the dataset required by the training job
TrainUrl string `json:"train_url"` // OBS path URL for the training job's output files
// LogUrl is serialized as "log_url".
// NOTE(review): presumably the location where job logs are written - confirm.
LogUrl string `json:"log_url"`
// UserImageUrl is serialized as "user_image_url".
// NOTE(review): presumably a container image reference - confirm.
UserImageUrl string `json:"user_image_url"`
// UserCommand is serialized as "user_command".
// NOTE(review): presumably the command run inside the user image - confirm.
UserCommand string `json:"user_command"`
// CreateVersion is serialized as "create_version".
CreateVersion bool `json:"create_version"`
// Flavor is serialized as "flavor".
Flavor Flavor `json:"flavor"`
// PoolID is serialized as "pool_id".
PoolID string `json:"pool_id"`
}

type CreateTrainJobParams struct {
JobName string `json:"job_name"`
Description string `json:"job_desc"`
@@ -952,6 +974,8 @@ type TrainJobVersionConfig struct {
Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"`
PreVersionId int64 `json:"pre_version_id"`
UserImageUrl string `json:"user_image_url"`
UserCommand string `json:"user_command"`
}

type CreateConfigParams struct {


+ 7
- 4
modules/modelarts/modelarts.go View File

@@ -97,6 +97,8 @@ type GenerateTrainJobReq struct {
VersionCount int
EngineName string
TotalVersionCount int
UserImageUrl string
UserCommand string
}

type GenerateInferenceJobReq struct {
@@ -386,15 +388,14 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error

func GenerateModelConvertTrainJob(req *GenerateTrainJobReq) (*models.CreateTrainJobResult, error) {

return createTrainJob(models.CreateTrainJobParams{
return createTrainJobUserImage(models.CreateUserImageTrainJobParams{
JobName: req.JobName,
Description: req.Description,
Config: models.Config{
Config: models.UserImageConfig{
WorkServerNum: req.WorkServerNumber,
AppUrl: req.CodeObsPath,
BootFileUrl: req.BootFileUrl,
DataUrl: req.DataUrl,
EngineID: req.EngineID,
TrainUrl: req.TrainUrl,
LogUrl: req.LogUrl,
PoolID: req.PoolID,
@@ -402,7 +403,9 @@ func GenerateModelConvertTrainJob(req *GenerateTrainJobReq) (*models.CreateTrain
Flavor: models.Flavor{
Code: req.FlavorCode,
},
Parameter: req.Parameters,
Parameter: req.Parameters,
UserImageUrl: req.UserImageUrl,
UserCommand: req.UserCommand,
},
})
}


+ 56
- 0
modules/modelarts/resty.go View File

@@ -472,6 +472,62 @@ sendjob:
return &result, nil
}

// createTrainJobUserImage posts createJobParams as a JSON body to the
// train-job endpoint (HOST + "/v1/" + setting.ProjectID + urlTrainJob) and
// decodes the reply into a models.CreateTrainJobResult.
//
// If the server answers 401 Unauthorized, getToken is called once and the
// request is sent a second time; a second 401 is handled like any other
// non-200 status.
//
// Returns:
//   - (nil, err) when the HTTP call itself fails;
//   - (&result, err) when the status is not 200 (the error body is parsed as
//     models.ErrorResult; "Invalid OBS path" messages for the boot file or the
//     dataset are mapped to dedicated errors) or when result.IsSuccess is false;
//   - (&result, nil) on success.
func createTrainJobUserImage(createJobParams models.CreateUserImageTrainJobParams) (*models.CreateTrainJobResult, error) {
	checkSetting()
	client := getRestyClient()
	var result models.CreateTrainJobResult

	// attempt 0 is the initial request; attempt 1 is the single resend that
	// follows a token refresh.
	for attempt := 0; ; attempt++ {
		request := client.R().
			SetHeader("Content-Type", "application/json").
			SetAuthToken(TOKEN).
			SetBody(createJobParams).
			SetResult(&result)

		res, err := request.Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob)
		if err != nil {
			return nil, fmt.Errorf("resty create train-job: %s", err)
		}

		// The marshal error is discarded: the dump is only written to the log.
		payload, _ := json.Marshal(createJobParams)
		log.Info("%s", payload)

		status := res.StatusCode()
		if status == http.StatusUnauthorized && attempt < 1 {
			_ = getToken()
			continue
		}

		if status != http.StatusOK {
			var failure models.ErrorResult
			if decodeErr := json.Unmarshal([]byte(res.String()), &failure); decodeErr != nil {
				log.Error("json.Unmarshal failed(%s): %v", res.String(), decodeErr.Error())
				return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), decodeErr.Error())
			}
			log.Error("createTrainJob failed(%d):%s(%s)", status, failure.ErrorCode, failure.ErrorMsg)

			// Server messages that identify a bad boot file or dataset path.
			invalidBootFile := "Invalid OBS path '" + createJobParams.Config.BootFileUrl + "'."
			invalidDataSet := "Invalid OBS path '" + createJobParams.Config.DataUrl + "'."

			// Cases are evaluated top to bottom, so the boot-file match wins
			// when both paths are identical.
			switch failure.ErrorMsg {
			case invalidBootFile:
				log.Error("启动文件错误!createTrainJob failed(%d):%s(%s)", status, failure.ErrorCode, failure.ErrorMsg)
				return &result, fmt.Errorf("启动文件错误!")
			case invalidDataSet:
				log.Error("数据集错误!createTrainJob failed(%d):%s(%s)", status, failure.ErrorCode, failure.ErrorMsg)
				return &result, fmt.Errorf("数据集错误!")
			}
			return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", status, failure.ErrorCode, failure.ErrorMsg)
		}

		if result.IsSuccess {
			return &result, nil
		}

		log.Error("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
		return &result, fmt.Errorf("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
	}
}

func createTrainJob(createJobParams models.CreateTrainJobParams) (*models.CreateTrainJobResult, error) {
checkSetting()
client := getRestyClient()


+ 3
- 0
routers/repo/ai_model_convert.go View File

@@ -203,6 +203,7 @@ func createNpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context
engineId = int64(NPU_TENSORFLOW_IMAGE_ID)
bootfile = TensorFlowNpuBootFile
}
userCommand := "/bin/bash /home/work/run_train.sh " + codeObsPath + " /code/" + bootfile
req := &modelarts.GenerateTrainJobReq{
JobName: modelConvert.ID,
DisplayJobName: modelConvert.Name,
@@ -220,6 +221,8 @@ func createNpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context
PoolID: NPU_PoolID,
Parameters: param,
BranchName: DefaultBranchName,
UserImageUrl: "swr.cn-south-222.ai.pcl.cn/openi/mindspore1.6.1_train_v1_openi:v3_ascend",
UserCommand: userCommand,
}
result, err := modelarts.GenerateModelConvertTrainJob(req)
if err == nil {


Loading…
Cancel
Save