diff --git a/modules/structs/pipeline.go b/modules/structs/pipeline.go new file mode 100644 index 000000000..fd26d1b51 --- /dev/null +++ b/modules/structs/pipeline.go @@ -0,0 +1,23 @@ +package structs + +type Pipeline struct { + ID int64 `json:"id"` + Name string `json:"name"` + Status string `json:"status"` +} +type NodeInfo struct { + Name string `json:"name"` + Status string `json:"status"` + Code string `json:"code"` + Message string `json:"message"` +} + +type PipelineNotification struct { + Type int `json:"type"` + Username string `json:"username"` + Reponame string `json:"reponame"` + Pipeline Pipeline `json:"pipeline"` + PipelineRunId string `json:"pipeline_run_id"` + Node NodeInfo `json:"node"` + OccurTime int64 `json:"occur_time"` +} diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 309a484ce..66a73862b 100755 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -543,6 +543,10 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/get_multipart_url", repo.GetMultipartUploadUrl) m.Post("/complete_multipart", repo.CompleteMultipart) + }, reqToken()) + m.Group("/pipeline", func() { + m.Post("/notification", bind(api.PipelineNotification{}), notify.PipelineNotify) + }, reqToken()) // Notifications @@ -995,6 +999,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/detail", reqToken(), reqRepoReader(models.UnitTypeCloudBrain), repo.CloudBrainShow) m.Get("/model_list", repo.CloudBrainModelList) m.Post("/stop_version", cloudbrain.AdminOrOwnerOrJobCreaterRightForTrain, repo_ext.CloudBrainStop) + m.Put("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRightForTrain, repo.GeneralCloudBrainJobStop) }) }) m.Group("/inference-job", func() { @@ -1019,6 +1024,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/query_modelfile_for_predict", repo.QueryModelFileForPredict) m.Get("/query_train_model", repo.QueryTrainModelList) m.Post("/create_model_convert", repo.CreateModelConvert) + m.Post("/convert_stop", repo.StopModelConvert) m.Get("/show_model_convert_page", repo.ShowModelConvertPage) m.Get("/query_model_convert_byId", repo.QueryModelConvertById) diff --git a/routers/api/v1/notify/pipeline.go b/routers/api/v1/notify/pipeline.go new file mode 100644 index 000000000..021af20dc --- /dev/null +++ b/routers/api/v1/notify/pipeline.go @@ -0,0 +1,15 @@ +package notify + +import ( + "net/http" + + "code.gitea.io/gitea/models" + "code.gitea.io/gitea/modules/context" + api "code.gitea.io/gitea/modules/structs" +) + +func PipelineNotify(ctx *context.APIContext, form api.PipelineNotification) { + + ctx.JSON(http.StatusOK, models.BaseOKMessageApi) + +} diff --git a/routers/api/v1/repo/cloudbrain.go b/routers/api/v1/repo/cloudbrain.go index cd8340c41..1c5a58b47 100755 --- a/routers/api/v1/repo/cloudbrain.go +++ b/routers/api/v1/repo/cloudbrain.go @@ -17,6 +17,8 @@ import ( "strings" "time" + "code.gitea.io/gitea/modules/grampus" + cloudbrainService "code.gitea.io/gitea/services/cloudbrain" "code.gitea.io/gitea/modules/convert" @@ -80,6 +82,30 @@ func CloudBrainShow(ctx *context.APIContext) { ctx.JSON(http.StatusOK, models.BaseMessageWithDataApi{Code: 0, Message: "", Data: convert.ToCloudBrain(task)}) } +func GeneralCloudBrainJobStop(ctx *context.APIContext) { + task := ctx.Cloudbrain + if task.IsTerminal() { + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("cloudbrain.Already_stopped")) + return + } + var err error + + if ctx.Cloudbrain.Type == models.TypeCloudBrainOne { + err = cloudbrain.StopJob(task.JobID) + } else if ctx.Cloudbrain.Type == models.TypeCloudBrainTwo { + _, err = modelarts.StopTrainJob(task.JobID, strconv.FormatInt(task.VersionID, 10)) + } else { + _, err = grampus.StopJob(task.JobID) + } + + if err != nil { + log.Warn("cloud brain stopped failed.", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("cloudbrain.Stopped_failed")) + return + } + + ctx.JSON(http.StatusOK, models.BaseOKMessageApi) +} func CreateFileNoteBook(ctx *context.APIContext, option api.CreateFileNotebookJobOption) { cloudbrainTask.FileNotebookCreate(ctx.Context, option) } diff --git a/routers/api/v1/repo/modelmanage.go b/routers/api/v1/repo/modelmanage.go index 3989ec56c..a8166732d 100644 --- a/routers/api/v1/repo/modelmanage.go +++ b/routers/api/v1/repo/modelmanage.go @@ -89,6 +89,11 @@ func CreateModelConvert(ctx *context.APIContext) { routerRepo.SaveModelConvert(ctx.Context) } +func StopModelConvert(ctx *context.APIContext) { + log.Info("StopModelConvert by api.") + routerRepo.StopModelConvertApi(ctx.Context) +} + func ShowModelConvertPage(ctx *context.APIContext) { log.Info("ShowModelConvertPage by api.") modelResult, count, err := routerRepo.GetModelConvertPageData(ctx.Context) diff --git a/routers/repo/ai_model_convert.go b/routers/repo/ai_model_convert.go index dda410def..4b22998f5 100644 --- a/routers/repo/ai_model_convert.go +++ b/routers/repo/ai_model_convert.go @@ -573,13 +573,10 @@ func deleteCloudBrainTask(task *models.AiModelConvert) { } } -func StopModelConvert(ctx *context.Context) { - id := ctx.Params(":id") - log.Info("stop model convert start.id=" + id) +func stopModelConvert(id string) error { job, err := models.QueryModelConvertById(id) if err != nil { - ctx.ServerError("Not found task.", err) - return + return err } if job.IsGpuTrainTask() { err = cloudbrain.StopJob(job.CloudBrainTaskId) @@ -600,6 +597,35 @@ func StopModelConvert(ctx *context.Context) { err = models.UpdateModelConvert(job) if err != nil { log.Error("UpdateModelConvert failed:", err) + return err + } + return nil +} + +func StopModelConvertApi(ctx *context.Context) { + id := ctx.Params(":id") + log.Info("stop model convert start.id=" + id) + err := stopModelConvert(id) + if err == nil { + ctx.JSON(200, map[string]string{ + "code": "0", + "msg": "succeed", + }) + } else { + ctx.JSON(200, map[string]string{ + "code": "1", + "msg": err.Error(), + }) + } +} + +func StopModelConvert(ctx *context.Context) { + id := ctx.Params(":id") + log.Info("stop model convert start.id=" + id) + err := stopModelConvert(id) + if err != nil { + ctx.ServerError("Not found task.", err) + return } ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelmanage/convert_model") }