diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 30dd08aa4..c59554208 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -564,6 +564,17 @@ type FlavorInfo struct { Desc string `json:"desc"` } +type SpecialPools struct { + Pools []*SpecialPool `json:"pools"` +} +type SpecialPool struct { + Org string `json:"org"` + Type string `json:"type"` + IsExclusive bool `json:"isExclusive"` + Pool []*GpuInfo `json:"pool"` + JobType []string `json:"jobType"` +} + type ImageInfosModelArts struct { ImageInfo []*ImageInfoModelArts `json:"image_info"` } diff --git a/modules/grampus/grampus.go b/modules/grampus/grampus.go index 11749a741..47734c1aa 100755 --- a/modules/grampus/grampus.go +++ b/modules/grampus/grampus.go @@ -1,12 +1,16 @@ package grampus import ( + "encoding/json" + "strings" + + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/notification" "code.gitea.io/gitea/modules/timeutil" - "strings" ) const ( @@ -28,6 +32,8 @@ var ( poolInfos *models.PoolInfos FlavorInfos *models.FlavorInfos ImageInfos *models.ImageInfosModelArts + + SpecialPools *models.SpecialPools ) type GenerateTrainJobReq struct { @@ -63,6 +69,27 @@ type GenerateTrainJobReq struct { func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) { createTime := timeutil.TimeStampNow() + + var CenterID []string + var CenterName []string + + if SpecialPools != nil { + for _, pool := range SpecialPools.Pools { + if !pool.IsExclusive && strings.Contains(req.ComputeResource, pool.Type) { + org, _ := models.GetOrgByName(pool.Org) + if org != nil { + isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID) + if isOrgMember { + for _, info := range pool.Pool { + CenterID = append(CenterID, info.Queue) + CenterName = append(CenterName, info.Value) + } + } + } + } + } + } + jobResult, err := createJob(models.CreateGrampusJobRequest{ Name: req.JobName, Tasks: []models.GrampusTasks{ @@ -72,6 +99,8 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error ResourceSpecId: req.ResourceSpecId, ImageId: req.ImageId, ImageUrl: req.ImageUrl, + CenterID: CenterID, + CenterName: CenterName, ReplicaNum: 0, }, }, @@ -136,3 +165,8 @@ func TransTrainJobStatus(status string) string { return strings.ToUpper(status) } +func InitSpecialPool() { + if SpecialPools == nil && setting.Grampus.SpecialPools != "" { + json.Unmarshal([]byte(setting.Grampus.SpecialPools), &SpecialPools) + } +} diff --git a/options/locale/locale_en-US.ini b/options/locale/locale_en-US.ini index bcab8d829..13f1e3a36 100755 --- a/options/locale/locale_en-US.ini +++ b/options/locale/locale_en-US.ini @@ -1178,6 +1178,7 @@ model.manage.model_accuracy = Model Accuracy grampus.train_job.ai_center = AI Center grampus.dataset_path_rule = The code is storaged in /tmp/code;the dataset is storaged in /tmp/dataset;and please put your model into /tmp/output, then you can download it online。 +grampus.no_operate_right = You have no right to do this operation. template.items = Template Items template.git_content = Git Content (Default Branch) diff --git a/options/locale/locale_zh-CN.ini b/options/locale/locale_zh-CN.ini index 46d47238d..86ed4a937 100755 --- a/options/locale/locale_zh-CN.ini +++ b/options/locale/locale_zh-CN.ini @@ -1193,6 +1193,8 @@ model.manage.model_accuracy = 模型精度 grampus.train_job.ai_center=智算中心 grampus.dataset_path_rule = 训练脚本存储在/tmp/code中,数据集存储在/tmp/dataset中,训练输出请存储在/tmp/output中以供后续下载。 +grampus.no_operate_right = 您没有权限创建这类任务。 + template.items=模板选项 template.git_content=Git数据(默认分支) template.git_hooks=Git 钩子 diff --git a/routers/repo/grampus.go b/routers/repo/grampus.go index 35e2c5feb..b92644acc 100755 --- a/routers/repo/grampus.go +++ b/routers/repo/grampus.go @@ -71,6 +71,25 @@ func grampusTrainJobNewDataPrepare(ctx *context.Context, processType string) err ctx.Data["images"] = images.Infos } + grampus.InitSpecialPool() + + ctx.Data["GPUEnabled"] = true + ctx.Data["NPUEnabled"] = true + + if grampus.SpecialPools != nil { + for _, pool := range grampus.SpecialPools.Pools { + if pool.IsExclusive { + org, _ := models.GetOrgByName(pool.Org) + if org != nil { + isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID) + if !isOrgMember { + ctx.Data[pool.Type+"Enabled"] = false + } + } + } + } + } + //get valid resource specs specs, err := grampus.GetResourceSpecs(processType) if err != nil { @@ -122,10 +141,17 @@ func GrampusTrainJobGpuCreate(ctx *context.Context, form auth.CreateGrampusTrain image := strings.TrimSpace(form.Image) if !jobNamePattern.MatchString(displayJobName) { + grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeGPU) ctx.RenderWithErr(ctx.Tr("repo.cloudbrain_jobname_err"), tplGrampusTrainJobGPUNew, &form) return } + errStr := checkSpecialPool(ctx, "GPU") + if errStr != "" { + grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeGPU) + ctx.RenderWithErr(errStr, tplGrampusTrainJobGPUNew, &form) + } + //check count limit count, err := models.GetGrampusCountByUserID(ctx.User.ID, string(models.JobTypeTrain), models.GPUResource) if err != nil { @@ -257,6 +283,28 @@ func GrampusTrainJobGpuCreate(ctx *context.Context, form auth.CreateGrampusTrain ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelarts/train-job") } +func checkSpecialPool(ctx *context.Context, resourceType string) string { + grampus.InitSpecialPool() + if grampus.SpecialPools != nil { + for _, pool := range grampus.SpecialPools.Pools { + + if pool.IsExclusive && pool.Type == resourceType { + + org, _ := models.GetOrgByName(pool.Org) + if org != nil { + isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID) + if !isOrgMember { + return ctx.Tr("repo.grampus.no_operate_right") + } + } + } + + } + + } + return "" +} + func GrampusTrainJobNpuCreate(ctx *context.Context, form auth.CreateGrampusTrainJobForm) { displayJobName := form.DisplayJobName jobName := util.ConvertDisplayJobNameToJobName(displayJobName) @@ -275,10 +323,17 @@ func GrampusTrainJobNpuCreate(ctx *context.Context, form auth.CreateGrampusTrain engineName := form.EngineName if !jobNamePattern.MatchString(displayJobName) { + grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeNPU) ctx.RenderWithErr(ctx.Tr("repo.cloudbrain_jobname_err"), tplGrampusTrainJobNPUNew, &form) return } + errStr := checkSpecialPool(ctx, "NPU") + if errStr != "" { + grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeNPU) + ctx.RenderWithErr(errStr, tplGrampusTrainJobGPUNew, &form) + } + //check count limit count, err := models.GetGrampusCountByUserID(ctx.User.ID, string(models.JobTypeTrain), models.NPUResource) if err != nil {