| @@ -216,7 +216,7 @@ type CloudbrainsOptions struct { | |||||
| JobType string | JobType string | ||||
| VersionName string | VersionName string | ||||
| IsLatestVersion string | IsLatestVersion string | ||||
| JobTypeNot bool | |||||
| JobTypeNot bool | |||||
| } | } | ||||
| type TaskPod struct { | type TaskPod struct { | ||||
| @@ -650,6 +650,28 @@ type Config struct { | |||||
| Flavor Flavor `json:"flavor"` | Flavor Flavor `json:"flavor"` | ||||
| PoolID string `json:"pool_id"` | PoolID string `json:"pool_id"` | ||||
| } | } | ||||
| type CreateInferenceJobParams struct { | |||||
| JobName string `json:"job_name"` | |||||
| Description string `json:"job_desc"` | |||||
| InfConfig InfConfig `json:"config"` | |||||
| WorkspaceID string `json:"workspace_id"` | |||||
| } | |||||
| type InfConfig struct { | |||||
| WorkServerNum int `json:"worker_server_num"` | |||||
| AppUrl string `json:"app_url"` //训练作业的代码目录 | |||||
| BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 | |||||
| Parameter []Parameter `json:"parameter"` | |||||
| DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL | |||||
| EngineID int64 `json:"engine_id"` | |||||
| // TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL | |||||
| LogUrl string `json:"log_url"` | |||||
| //UserImageUrl string `json:"user_image_url"` | |||||
| //UserCommand string `json:"user_command"` | |||||
| CreateVersion bool `json:"create_version"` | |||||
| Flavor Flavor `json:"flavor"` | |||||
| PoolID string `json:"pool_id"` | |||||
| } | |||||
| type CreateTrainJobVersionParams struct { | type CreateTrainJobVersionParams struct { | ||||
| Description string `json:"job_desc"` | Description string `json:"job_desc"` | ||||
| @@ -47,7 +47,7 @@ const ( | |||||
| TrainUrl = "train_url" | TrainUrl = "train_url" | ||||
| DataUrl = "data_url" | DataUrl = "data_url" | ||||
| ResultUrl = "result_url" | ResultUrl = "result_url" | ||||
| CkptName = "ckpt_name" | |||||
| CkptUrl = "ckpt_url" | |||||
| PerPage = 10 | PerPage = 10 | ||||
| IsLatestVersion = "1" | IsLatestVersion = "1" | ||||
| NotLatestVersion = "0" | NotLatestVersion = "0" | ||||
| @@ -477,16 +477,16 @@ func GetVersionOutputPathByTotalVersionCount(TotalVersionCount int) (VersionOutp | |||||
| } | } | ||||
| func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) { | func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) { | ||||
| jobResult, err := createTrainJob(models.CreateTrainJobParams{ | |||||
| jobResult, err := createInferenceJob(models.CreateInferenceJobParams{ | |||||
| JobName: req.JobName, | JobName: req.JobName, | ||||
| Description: req.Description, | Description: req.Description, | ||||
| Config: models.Config{ | |||||
| InfConfig: models.InfConfig{ | |||||
| WorkServerNum: req.WorkServerNumber, | WorkServerNum: req.WorkServerNumber, | ||||
| AppUrl: req.CodeObsPath, | AppUrl: req.CodeObsPath, | ||||
| BootFileUrl: req.BootFileUrl, | BootFileUrl: req.BootFileUrl, | ||||
| DataUrl: req.DataUrl, | DataUrl: req.DataUrl, | ||||
| EngineID: req.EngineID, | EngineID: req.EngineID, | ||||
| TrainUrl: req.TrainUrl, | |||||
| // TrainUrl: req.TrainUrl, | |||||
| LogUrl: req.LogUrl, | LogUrl: req.LogUrl, | ||||
| PoolID: req.PoolID, | PoolID: req.PoolID, | ||||
| CreateVersion: true, | CreateVersion: true, | ||||
| @@ -874,3 +874,59 @@ sendjob: | |||||
| return &result, nil | return &result, nil | ||||
| } | } | ||||
| func createInferenceJob(createJobParams models.CreateInferenceJobParams) (*models.CreateTrainJobResult, error) { | |||||
| checkSetting() | |||||
| client := getRestyClient() | |||||
| var result models.CreateTrainJobResult | |||||
| retry := 0 | |||||
| sendjob: | |||||
| res, err := client.R(). | |||||
| SetHeader("Content-Type", "application/json"). | |||||
| SetAuthToken(TOKEN). | |||||
| SetBody(createJobParams). | |||||
| SetResult(&result). | |||||
| Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob) | |||||
| if err != nil { | |||||
| return nil, fmt.Errorf("resty create train-job: %s", err) | |||||
| } | |||||
| req, _ := json.Marshal(createJobParams) | |||||
| log.Info("%s", req) | |||||
| if res.StatusCode() == http.StatusUnauthorized && retry < 1 { | |||||
| retry++ | |||||
| _ = getToken() | |||||
| goto sendjob | |||||
| } | |||||
| if res.StatusCode() != http.StatusOK { | |||||
| var temp models.ErrorResult | |||||
| if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { | |||||
| log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||||
| return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||||
| } | |||||
| log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| BootFileErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.BootFileUrl + "'." | |||||
| DataSetErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.DataUrl + "'." | |||||
| if temp.ErrorMsg == BootFileErrorMsg { | |||||
| log.Error("启动文件错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| return &result, fmt.Errorf("启动文件错误!") | |||||
| } | |||||
| if temp.ErrorMsg == DataSetErrorMsg { | |||||
| log.Error("数据集错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| return &result, fmt.Errorf("数据集错误!") | |||||
| } | |||||
| return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| } | |||||
| if !result.IsSuccess { | |||||
| log.Error("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||||
| return &result, fmt.Errorf("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||||
| } | |||||
| return &result, nil | |||||
| } | |||||
| @@ -1596,6 +1596,8 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference | |||||
| modelVersion := form.ModelVersion | modelVersion := form.ModelVersion | ||||
| ckptName := form.CkptName | ckptName := form.CkptName | ||||
| ckptUrl := form.TrainUrl + form.CkptName | |||||
| count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) | count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) | ||||
| if err != nil { | if err != nil { | ||||
| log.Error("GetCloudbrainTrainJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) | log.Error("GetCloudbrainTrainJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) | ||||
| @@ -1675,8 +1677,8 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference | |||||
| Label: modelarts.ResultUrl, | Label: modelarts.ResultUrl, | ||||
| Value: "s3:/" + resultObsPath, | Value: "s3:/" + resultObsPath, | ||||
| }, models.Parameter{ | }, models.Parameter{ | ||||
| Label: modelarts.CkptName, | |||||
| Value: ckptName, | |||||
| Label: modelarts.CkptUrl, | |||||
| Value: "s3:/" + ckptUrl, | |||||
| }) | }) | ||||
| if len(params) != 0 { | if len(params) != 0 { | ||||
| err := json.Unmarshal([]byte(params), ¶meters) | err := json.Unmarshal([]byte(params), ¶meters) | ||||