@@ -893,6 +893,28 @@ type NotebookDelResult struct { | |||
InstanceID string `json:"instance_id"` | |||
} | |||
type CreateUserImageTrainJobParams struct { | |||
JobName string `json:"job_name"` | |||
Description string `json:"job_desc"` | |||
Config UserImageConfig `json:"config"` | |||
WorkspaceID string `json:"workspace_id"` | |||
} | |||
type UserImageConfig struct { | |||
WorkServerNum int `json:"worker_server_num"` | |||
AppUrl string `json:"app_url"` //训练作业的代码目录 | |||
BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 | |||
Parameter []Parameter `json:"parameter"` | |||
DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL | |||
TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL | |||
LogUrl string `json:"log_url"` | |||
UserImageUrl string `json:"user_image_url"` | |||
UserCommand string `json:"user_command"` | |||
CreateVersion bool `json:"create_version"` | |||
Flavor Flavor `json:"flavor"` | |||
PoolID string `json:"pool_id"` | |||
} | |||
type CreateTrainJobParams struct { | |||
JobName string `json:"job_name"` | |||
Description string `json:"job_desc"` | |||
@@ -952,6 +974,8 @@ type TrainJobVersionConfig struct { | |||
Flavor Flavor `json:"flavor"` | |||
PoolID string `json:"pool_id"` | |||
PreVersionId int64 `json:"pre_version_id"` | |||
UserImageUrl string `json:"user_image_url"` | |||
UserCommand string `json:"user_command"` | |||
} | |||
type CreateConfigParams struct { | |||
@@ -97,6 +97,8 @@ type GenerateTrainJobReq struct { | |||
VersionCount int | |||
EngineName string | |||
TotalVersionCount int | |||
UserImageUrl string | |||
UserCommand string | |||
} | |||
type GenerateInferenceJobReq struct { | |||
@@ -386,15 +388,14 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error | |||
func GenerateModelConvertTrainJob(req *GenerateTrainJobReq) (*models.CreateTrainJobResult, error) { | |||
return createTrainJob(models.CreateTrainJobParams{ | |||
return createTrainJobUserImage(models.CreateUserImageTrainJobParams{ | |||
JobName: req.JobName, | |||
Description: req.Description, | |||
Config: models.Config{ | |||
Config: models.UserImageConfig{ | |||
WorkServerNum: req.WorkServerNumber, | |||
AppUrl: req.CodeObsPath, | |||
BootFileUrl: req.BootFileUrl, | |||
DataUrl: req.DataUrl, | |||
EngineID: req.EngineID, | |||
TrainUrl: req.TrainUrl, | |||
LogUrl: req.LogUrl, | |||
PoolID: req.PoolID, | |||
@@ -402,7 +403,9 @@ func GenerateModelConvertTrainJob(req *GenerateTrainJobReq) (*models.CreateTrain | |||
Flavor: models.Flavor{ | |||
Code: req.FlavorCode, | |||
}, | |||
Parameter: req.Parameters, | |||
Parameter: req.Parameters, | |||
UserImageUrl: req.UserImageUrl, | |||
UserCommand: req.UserCommand, | |||
}, | |||
}) | |||
} | |||
@@ -472,6 +472,62 @@ sendjob: | |||
return &result, nil | |||
} | |||
func createTrainJobUserImage(createJobParams models.CreateUserImageTrainJobParams) (*models.CreateTrainJobResult, error) { | |||
checkSetting() | |||
client := getRestyClient() | |||
var result models.CreateTrainJobResult | |||
retry := 0 | |||
sendjob: | |||
res, err := client.R(). | |||
SetHeader("Content-Type", "application/json"). | |||
SetAuthToken(TOKEN). | |||
SetBody(createJobParams). | |||
SetResult(&result). | |||
Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob) | |||
if err != nil { | |||
return nil, fmt.Errorf("resty create train-job: %s", err) | |||
} | |||
req, _ := json.Marshal(createJobParams) | |||
log.Info("%s", req) | |||
if res.StatusCode() == http.StatusUnauthorized && retry < 1 { | |||
retry++ | |||
_ = getToken() | |||
goto sendjob | |||
} | |||
if res.StatusCode() != http.StatusOK { | |||
var temp models.ErrorResult | |||
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { | |||
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
} | |||
log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
BootFileErrorMsg := "Invalid OBS path '" + createJobParams.Config.BootFileUrl + "'." | |||
DataSetErrorMsg := "Invalid OBS path '" + createJobParams.Config.DataUrl + "'." | |||
if temp.ErrorMsg == BootFileErrorMsg { | |||
log.Error("启动文件错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
return &result, fmt.Errorf("启动文件错误!") | |||
} | |||
if temp.ErrorMsg == DataSetErrorMsg { | |||
log.Error("数据集错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
return &result, fmt.Errorf("数据集错误!") | |||
} | |||
return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
} | |||
if !result.IsSuccess { | |||
log.Error("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||
return &result, fmt.Errorf("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||
} | |||
return &result, nil | |||
} | |||
func createTrainJob(createJobParams models.CreateTrainJobParams) (*models.CreateTrainJobResult, error) { | |||
checkSetting() | |||
client := getRestyClient() | |||
@@ -203,6 +203,7 @@ func createNpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context | |||
engineId = int64(NPU_TENSORFLOW_IMAGE_ID) | |||
bootfile = TensorFlowNpuBootFile | |||
} | |||
userCommand := "/bin/bash /home/work/run_train.sh " + codeObsPath + " /code/" + bootfile | |||
req := &modelarts.GenerateTrainJobReq{ | |||
JobName: modelConvert.ID, | |||
DisplayJobName: modelConvert.Name, | |||
@@ -220,6 +221,8 @@ func createNpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context | |||
PoolID: NPU_PoolID, | |||
Parameters: param, | |||
BranchName: DefaultBranchName, | |||
UserImageUrl: "swr.cn-south-222.ai.pcl.cn/openi/mindspore1.6.1_train_v1_openi:v3_ascend", | |||
UserCommand: userCommand, | |||
} | |||
result, err := modelarts.GenerateModelConvertTrainJob(req) | |||
if err == nil { | |||