diff --git a/models/ai_model_manage.go b/models/ai_model_manage.go index b6169b806..7ed6b1051 100644 --- a/models/ai_model_manage.go +++ b/models/ai_model_manage.go @@ -4,7 +4,9 @@ import ( "fmt" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/timeutil" + "xorm.io/builder" ) type AiModelManage struct { @@ -32,6 +34,16 @@ type AiModelManage struct { UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` } +type AiModelQueryOptions struct { + ListOptions + RepoID int64 // include all repos if empty + UserID int64 + ModelID string + SortType string + // JobStatus CloudbrainStatus + Type int +} + func SaveModelToDb(model *AiModelManage) error { sess := x.NewSession() defer sess.Close() @@ -68,3 +80,59 @@ func QueryModelByName(name string, uid int64) []*AiModelManage { sess.Find(&aiModelManageList) return aiModelManageList } + +func QueryModel(opts *AiModelQueryOptions) ([]*AiModelManage, int64, error) { + sess := x.NewSession() + defer sess.Close() + + var cond = builder.NewCond() + if opts.RepoID > 0 { + cond = cond.And( + builder.Eq{"ai_model_manage.repo_id": opts.RepoID}, + ) + } + + if opts.UserID > 0 { + cond = cond.And( + builder.Eq{"ai_model_manage.user_id": opts.UserID}, + ) + } + + if len(opts.ModelID) > 0 { + cond = cond.And( + builder.Eq{"ai_model_manage.id": opts.ModelID}, + ) + } + + if (opts.Type) >= 0 { + cond = cond.And( + builder.Eq{"ai_model_manage.type": opts.Type}, + ) + } + + count, err := sess.Where(cond).Count(new(AiModelManage)) + if err != nil { + return nil, 0, fmt.Errorf("Count: %v", err) + } + + if opts.Page >= 0 && opts.PageSize > 0 { + var start int + if opts.Page == 0 { + start = 0 + } else { + start = (opts.Page - 1) * opts.PageSize + } + sess.Limit(opts.PageSize, start) + } + + sess.OrderBy("ai_model_manage.created_unix DESC") + aiModelManages := make([]*AiModelManage, 0, setting.UI.IssuePagingNum) + if err := sess.Table(&AiModelManage{}).Where(cond). + Join("left", "`user`", "ai_model_manage.user_id = `user`.id"). + Find(&aiModelManages); err != nil { + return nil, 0, fmt.Errorf("Find: %v", err) + } + sess.Close() + + return aiModelManages, count, nil +} diff --git a/routers/private/internal.go b/routers/private/internal.go index 5b1d63bff..29debdfdc 100755 --- a/routers/private/internal.go +++ b/routers/private/internal.go @@ -45,5 +45,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Post("/tool/update_all_repo_commit_cnt", UpdateAllRepoCommitCnt) m.Post("/tool/repo_stat", RepoStatisticManually) m.Post("/tool/create_model", CreateModel) + m.Post("/tool/delete_model", DeleteModel) + m.Post("/tool/show_model", ShowModel) }, CheckInternalToken) } diff --git a/routers/private/tool.go b/routers/private/tool.go index 3b8322d94..f3487836b 100755 --- a/routers/private/tool.go +++ b/routers/private/tool.go @@ -56,3 +56,18 @@ func CreateModel(ctx *macaron.Context) { repo.SaveModelByParameters(trainTaskId, name, version, label, description, userId) } + +func DeleteModel(ctx *macaron.Context) { + id := ctx.Query("id") + repo.DeleteModelByID(id) +} + +func ShowModel(ctx *macaron.Context) { + repoId := ctx.QueryInt64("repoId") + modelResult, _, err := repo.QueryModelByParameters(repoId, 5) + if err == nil { + ctx.JSON(200, modelResult) + } else { + ctx.JSON(500, "query error.") + } +} diff --git a/routers/repo/ai_model_manage.go b/routers/repo/ai_model_manage.go index c647483a5..6ad37a814 100644 --- a/routers/repo/ai_model_manage.go +++ b/routers/repo/ai_model_manage.go @@ -43,7 +43,7 @@ func SaveModelByParameters(trainTaskId string, name string, version string, labe } cloudType = aiTask.Type //download model zip //train type - if cloudType == models.TypeCloudBrainTrainJob { + if cloudType == models.TypeCloudBrainTwo { modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "") if err == nil { @@ -174,7 +174,7 @@ func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir func DeleteModel(ctx *context.Context) { log.Info("delete model start.") id := ctx.Query("ID") - err := models.DeleteModelById(id) + err := DeleteModelByID(id) if err != nil { ctx.JSON(500, err.Error()) } else { @@ -184,14 +184,54 @@ func DeleteModel(ctx *context.Context) { } } +func DeleteModelByID(id string) error { + log.Info("delete model start. id=" + id) + return models.DeleteModelById(id) +} + func DownloadModel(ctx *context.Context) { log.Info("download model start.") } +func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) { + + return models.QueryModel(&models.AiModelQueryOptions{ + ListOptions: models.ListOptions{ + Page: page, + PageSize: setting.UI.IssuePagingNum, + }, + RepoID: repoId, + }) +} + func ShowModelInfo(ctx *context.Context) { - log.Info("ShowModelInfo.") + log.Info("ShowModelInfo start.") + + page := ctx.QueryInt("page") + if page <= 0 { + page = 1 + } + repoId := ctx.QueryInt64("repoId") + + modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{ + ListOptions: models.ListOptions{ + Page: page, + PageSize: setting.UI.IssuePagingNum, + }, + RepoID: repoId, + }) + if err != nil { + ctx.ServerError("Cloudbrain", err) + return + } + pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5) + pager.SetDefaultParams(ctx) + ctx.Data["Page"] = pager + ctx.Data["PageIsCloudBrain"] = true + ctx.Data["Tasks"] = modelResult + ctx.HTML(200, "") } func downloadModelFromCloudBrainOne(modelUUID string, jobName string, parentDir string) (string, int64, error) {