Browse Source

add modelarts interface

fix-2419
lewis 3 years ago
parent
commit
ef8e67aea5
3 changed files with 139 additions and 13 deletions
  1. +44
    -4
      models/cloudbrain.go
  2. +1
    -1
      modules/modelarts/modelarts.go
  3. +94
    -8
      modules/modelarts/resty.go

+ 44
- 4
models/cloudbrain.go View File

@@ -1193,12 +1193,52 @@ type LogFile struct {
Name string Name string
} }


type JobList struct {
JobName string `json:"job_name"`
JobID int64 `json:"job_id"`
VersionID int64 `json:"version_id"`
VersionCount int64 `json:"version_count"`
Description string `json:"job_desc"`
IntStatus int `json:"status"`
Status string
}

type GetTrainJobListResult struct { type GetTrainJobListResult struct {
ErrorResult ErrorResult
JobTotalCount int `json:"job_total_count"` //查询到的用户创建作业总数
JobCountLimit int `json:"job_count_limit"` //用户还可以创建训练作业的数量
Quotas int `json:"quotas"` //训练作业的运行数量上限
ParaConfigs []ParaConfig `json:"jobs"`
JobTotalCount int `json:"job_total_count"` //查询到的用户创建作业总数
JobCountLimit int `json:"job_count_limit"` //用户还可以创建训练作业的数量
Quotas int `json:"quotas"` //训练作业的运行数量上限
JobList []JobList `json:"jobs"`
}

type JobVersionList struct {
VersionName string `json:"version_name"`
VersionID int64 `json:"version_id"`
IntStatus int `json:"status"`
Status string
}

type GetTrainJobVersionListResult struct {
ErrorResult
JobID string `json:"job_id"`
JobName string `json:"job_name"`
JobDesc string `json:"job_desc"`
VersionCount int64 `json:"version_count"`
JobVersionList []JobVersionList `json:"versions"`
}

type NotebookList struct {
JobName string `json:"name"`
JobID string `json:"id"`
Status string `json:"status"`
}

type GetNotebookListResult struct {
TotalCount int64 `json:"total"` //总的记录数量
CurrentPage int `json:"current"` //当前页数
TotalPages int `json:"pages"` //总的页数
Size int `json:"size"` //每一页的数量
NotebookList []NotebookList `json:"data"`
} }


//Grampus //Grampus


+ 1
- 1
modules/modelarts/modelarts.go View File

@@ -506,7 +506,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
} }


//将训练任务的上一版本的isLatestVersion设置为"0" //将训练任务的上一版本的isLatestVersion设置为"0"
err = models.SetVersionCountAndLatestVersion(req.JobName, VersionTaskList[0].VersionName, VersionCount, NotLatestVersion, TotalVersionCount)
err = models.SetVersionCountAndLatestVersion(req.JobName, VersionTaskList[0].VersionName, VersionListCount, NotLatestVersion, VersionTaskList[0].TotalVersionCount)
if err != nil { if err != nil {
ctx.ServerError("Update IsLatestVersion failed", err) ctx.ServerError("Update IsLatestVersion failed", err)
return err return err


+ 94
- 8
modules/modelarts/resty.go View File

@@ -1175,10 +1175,10 @@ sendjob:
return &result, nil return &result, nil
} }


func GetTrainJobList(perPage, page int, sortBy, order, searchContent, status string) (*models.GetConfigListResult, error) {
func GetTrainJobList(perPage, page int, sortBy, order, searchContent, status string) (*models.GetTrainJobListResult, error) {
checkSetting() checkSetting()
client := getRestyClient() client := getRestyClient()
var result models.GetConfigListResult
var result models.GetTrainJobListResult


retry := 0 retry := 0


@@ -1190,14 +1190,13 @@ sendjob:
"sortBy": sortBy, "sortBy": sortBy,
"order": order, "order": order,
"search_content": searchContent, "search_content": searchContent,
"status": status,
}). }).
SetAuthToken(TOKEN). SetAuthToken(TOKEN).
SetResult(&result). SetResult(&result).
Get(HOST + "/v1/" + setting.ProjectID + urlTrainJob) Get(HOST + "/v1/" + setting.ProjectID + urlTrainJob)


if err != nil { if err != nil {
return nil, fmt.Errorf("resty GetConfigList: %v", err)
return nil, fmt.Errorf("resty GetTrainJobList: %v", err)
} }


if res.StatusCode() == http.StatusUnauthorized && retry < 1 { if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
@@ -1212,13 +1211,100 @@ sendjob:
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
} }
log.Error("GetConfigList failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("获取参数配置列表失败(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
log.Error("GetTrainJobList failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf(temp.ErrorMsg)
} }


if !result.IsSuccess { if !result.IsSuccess {
log.Error("GetConfigList failed(%s): %s", result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf("获取参数配置列表失败(%s): %s", result.ErrorCode, result.ErrorMsg)
log.Error("GetTrainJobList failed(%s): %s", result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf(result.ErrorMsg)
}

return &result, nil
}

func GetTrainJobVersionList(perPage, page int, jobID string) (*models.GetTrainJobVersionListResult, error) {
checkSetting()
client := getRestyClient()
var result models.GetTrainJobVersionListResult

retry := 0

sendjob:
res, err := client.R().
SetQueryParams(map[string]string{
"per_page": strconv.Itoa(perPage),
"page": strconv.Itoa(page),
}).
SetAuthToken(TOKEN).
SetResult(&result).
Get(HOST + "/v1/" + setting.ProjectID + urlTrainJob + "/" + jobID + "/versions")

if err != nil {
return nil, fmt.Errorf("resty GetTrainJobVersionList: %v", err)
}

if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
retry++
_ = getToken()
goto sendjob
}

if res.StatusCode() != http.StatusOK {
var temp models.ErrorResult
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
}
log.Error("GetTrainJobVersionList failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf(temp.ErrorMsg)
}

if !result.IsSuccess {
log.Error("GetTrainJobVersionList failed(%s): %s", result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf(result.ErrorMsg)
}

return &result, nil
}

func GetNotebookList(limit, page int, sortBy, order, searchContent, status string) (*models.GetNotebookListResult, error) {
checkSetting()
client := getRestyClient()
var result models.GetNotebookListResult

retry := 0

sendjob:
res, err := client.R().
SetQueryParams(map[string]string{
"limit": strconv.Itoa(limit),
"name": searchContent,
"sort_key": sortBy,
"sort_dir": order,
}).
SetAuthToken(TOKEN).
SetResult(&result).
Get(HOST + "/v1/" + setting.ProjectID + urlNotebook2)

if err != nil {
return nil, fmt.Errorf("resty GetNotebookList: %v", err)
}

if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
retry++
_ = getToken()
goto sendjob
}

if res.StatusCode() != http.StatusOK {
var temp models.ErrorResult
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
}
log.Error("GetNotebookList failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf(temp.ErrorMsg)
} }


return &result, nil return &result, nil


Loading…
Cancel
Save