| @@ -564,6 +564,17 @@ type FlavorInfo struct { | |||
| Desc string `json:"desc"` | |||
| } | |||
| type SpecialPools struct { | |||
| Pools []*SpecialPool `json:"pools"` | |||
| } | |||
| type SpecialPool struct { | |||
| Org string `json:"org"` | |||
| Type string `json:"type"` | |||
| IsExclusive bool `json:"isExclusive"` | |||
| Pool []*GpuInfo `json:"pool"` | |||
| JobType []string `json:"jobType"` | |||
| } | |||
| type ImageInfosModelArts struct { | |||
| ImageInfo []*ImageInfoModelArts `json:"image_info"` | |||
| } | |||
| @@ -1,12 +1,16 @@ | |||
| package grampus | |||
| import ( | |||
| "encoding/json" | |||
| "strings" | |||
| "code.gitea.io/gitea/modules/setting" | |||
| "code.gitea.io/gitea/models" | |||
| "code.gitea.io/gitea/modules/context" | |||
| "code.gitea.io/gitea/modules/log" | |||
| "code.gitea.io/gitea/modules/notification" | |||
| "code.gitea.io/gitea/modules/timeutil" | |||
| "strings" | |||
| ) | |||
| const ( | |||
| @@ -28,6 +32,8 @@ var ( | |||
| poolInfos *models.PoolInfos | |||
| FlavorInfos *models.FlavorInfos | |||
| ImageInfos *models.ImageInfosModelArts | |||
| SpecialPools *models.SpecialPools | |||
| ) | |||
| type GenerateTrainJobReq struct { | |||
| @@ -63,6 +69,27 @@ type GenerateTrainJobReq struct { | |||
| func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) { | |||
| createTime := timeutil.TimeStampNow() | |||
| var CenterID []string | |||
| var CenterName []string | |||
| if SpecialPools != nil { | |||
| for _, pool := range SpecialPools.Pools { | |||
| if !pool.IsExclusive && strings.Contains(req.ComputeResource, pool.Type) { | |||
| org, _ := models.GetOrgByName(pool.Org) | |||
| if org != nil { | |||
| isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID) | |||
| if isOrgMember { | |||
| for _, info := range pool.Pool { | |||
| CenterID = append(CenterID, info.Queue) | |||
| CenterName = append(CenterName, info.Value) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| jobResult, err := createJob(models.CreateGrampusJobRequest{ | |||
| Name: req.JobName, | |||
| Tasks: []models.GrampusTasks{ | |||
| @@ -72,6 +99,8 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error | |||
| ResourceSpecId: req.ResourceSpecId, | |||
| ImageId: req.ImageId, | |||
| ImageUrl: req.ImageUrl, | |||
| CenterID: CenterID, | |||
| CenterName: CenterName, | |||
| ReplicaNum: 0, | |||
| }, | |||
| }, | |||
| @@ -136,3 +165,8 @@ func TransTrainJobStatus(status string) string { | |||
| return strings.ToUpper(status) | |||
| } | |||
| func InitSpecialPool() { | |||
| if SpecialPools == nil && setting.Grampus.SpecialPools != "" { | |||
| json.Unmarshal([]byte(setting.Grampus.SpecialPools), &SpecialPools) | |||
| } | |||
| } | |||
| @@ -1178,6 +1178,7 @@ model.manage.model_accuracy = Model Accuracy | |||
| grampus.train_job.ai_center = AI Center | |||
| grampus.dataset_path_rule = The code is stored in /tmp/code; the dataset is stored in /tmp/dataset; please put your model into /tmp/output so that you can download it online. | |||
| grampus.no_operate_right = You have no right to do this operation. | |||
| template.items = Template Items | |||
| template.git_content = Git Content (Default Branch) | |||
| @@ -1193,6 +1193,8 @@ model.manage.model_accuracy = 模型精度 | |||
| grampus.train_job.ai_center=智算中心 | |||
| grampus.dataset_path_rule = 训练脚本存储在/tmp/code中,数据集存储在/tmp/dataset中,训练输出请存储在/tmp/output中以供后续下载。 | |||
| grampus.no_operate_right = 您没有权限创建这类任务。 | |||
| template.items=模板选项 | |||
| template.git_content=Git数据(默认分支) | |||
| template.git_hooks=Git 钩子 | |||
| @@ -71,6 +71,25 @@ func grampusTrainJobNewDataPrepare(ctx *context.Context, processType string) err | |||
| ctx.Data["images"] = images.Infos | |||
| } | |||
| grampus.InitSpecialPool() | |||
| ctx.Data["GPUEnabled"] = true | |||
| ctx.Data["NPUEnabled"] = true | |||
| if grampus.SpecialPools != nil { | |||
| for _, pool := range grampus.SpecialPools.Pools { | |||
| if pool.IsExclusive { | |||
| org, _ := models.GetOrgByName(pool.Org) | |||
| if org != nil { | |||
| isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID) | |||
| if !isOrgMember { | |||
| ctx.Data[pool.Type+"Enabled"] = false | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| //get valid resource specs | |||
| specs, err := grampus.GetResourceSpecs(processType) | |||
| if err != nil { | |||
| @@ -122,10 +141,17 @@ func GrampusTrainJobGpuCreate(ctx *context.Context, form auth.CreateGrampusTrain | |||
| image := strings.TrimSpace(form.Image) | |||
| if !jobNamePattern.MatchString(displayJobName) { | |||
| grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeGPU) | |||
| ctx.RenderWithErr(ctx.Tr("repo.cloudbrain_jobname_err"), tplGrampusTrainJobGPUNew, &form) | |||
| return | |||
| } | |||
| // Refuse the request when checkSpecialPool reports that the user may not | |||
| // use GPU resources (it returns a non-empty, already translated message). | |||
| errStr := checkSpecialPool(ctx, "GPU") | |||
| if errStr != "" { | |||
| grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeGPU) | |||
| ctx.RenderWithErr(errStr, tplGrampusTrainJobGPUNew, &form) | |||
| // The error page has been rendered; stop so that no job is created. | |||
| return | |||
| } | |||
| //check count limit | |||
| count, err := models.GetGrampusCountByUserID(ctx.User.ID, string(models.JobTypeTrain), models.GPUResource) | |||
| if err != nil { | |||
| @@ -257,6 +283,28 @@ func GrampusTrainJobGpuCreate(ctx *context.Context, form auth.CreateGrampusTrain | |||
| ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelarts/train-job") | |||
| } | |||
| // checkSpecialPool decides whether ctx.User may create a job on resourceType. | |||
| // | |||
| // It returns the translated "repo.grampus.no_operate_right" message when an | |||
| // exclusive pool whose Type equals resourceType belongs to an organization | |||
| // that resolves and of which the user is not a member. In every other case | |||
| // it returns the empty string. | |||
| func checkSpecialPool(ctx *context.Context, resourceType string) string { | |||
| grampus.InitSpecialPool() | |||
| if grampus.SpecialPools == nil { | |||
| return "" | |||
| } | |||
| for _, candidate := range grampus.SpecialPools.Pools { | |||
| // Only exclusive pools of the requested type restrict access. | |||
| if !candidate.IsExclusive || candidate.Type != resourceType { | |||
| continue | |||
| } | |||
| // The lookup error is discarded; a pool whose organization does not | |||
| // resolve imposes no restriction. | |||
| owner, _ := models.GetOrgByName(candidate.Org) | |||
| if owner == nil { | |||
| continue | |||
| } | |||
| // The membership error is discarded; only the boolean is consulted. | |||
| member, _ := models.IsOrganizationMember(owner.ID, ctx.User.ID) | |||
| if !member { | |||
| return ctx.Tr("repo.grampus.no_operate_right") | |||
| } | |||
| } | |||
| return "" | |||
| } | |||
| func GrampusTrainJobNpuCreate(ctx *context.Context, form auth.CreateGrampusTrainJobForm) { | |||
| displayJobName := form.DisplayJobName | |||
| jobName := util.ConvertDisplayJobNameToJobName(displayJobName) | |||
| @@ -275,10 +323,17 @@ func GrampusTrainJobNpuCreate(ctx *context.Context, form auth.CreateGrampusTrain | |||
| engineName := form.EngineName | |||
| if !jobNamePattern.MatchString(displayJobName) { | |||
| grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeNPU) | |||
| ctx.RenderWithErr(ctx.Tr("repo.cloudbrain_jobname_err"), tplGrampusTrainJobNPUNew, &form) | |||
| return | |||
| } | |||
| // Refuse the request when checkSpecialPool reports that the user may not | |||
| // use NPU resources (it returns a non-empty, already translated message). | |||
| errStr := checkSpecialPool(ctx, "NPU") | |||
| if errStr != "" { | |||
| grampusTrainJobNewDataPrepare(ctx, grampus.ProcessorTypeNPU) | |||
| // Render the NPU form, matching the data prepared on the line above. | |||
| ctx.RenderWithErr(errStr, tplGrampusTrainJobNPUNew, &form) | |||
| // The error page has been rendered; stop so that no job is created. | |||
| return | |||
| } | |||
| //check count limit | |||
| count, err := models.GetGrampusCountByUserID(ctx.User.ID, string(models.JobTypeTrain), models.NPUResource) | |||
| if err != nil { | |||