diff --git a/desc/schedule/pcm-schedule.api b/desc/schedule/pcm-schedule.api index b7e59dda..bbf05fc0 100644 --- a/desc/schedule/pcm-schedule.api +++ b/desc/schedule/pcm-schedule.api @@ -154,6 +154,7 @@ type ( // 调度资源信息:/queryResources QueryResourcesReq{ + Type string `json:"type"` ClusterIDs []string `json:"clusterIDs,optional"` } diff --git a/internal/cron/cron.go b/internal/cron/cron.go index 6602f508..c3b2ef1d 100644 --- a/internal/cron/cron.go +++ b/internal/cron/cron.go @@ -42,11 +42,16 @@ func AddCronGroup(svc *svc.ServiceContext) { svc.Cron.AddFunc("0 5/5 * * * *", func() { queryResource := schedule.NewQueryResourcesLogic(svc.HttpClient.R().Context(), svc) - rus, err := queryResource.QueryResourcesByClusterId(nil) + trainResrc, err := queryResource.QueryResourcesByClusterId(nil, "Train") if err != nil { logx.Error(err) } - svc.Scheduler.AiService.LocalCache[schedule.QUERY_RESOURCES] = rus + svc.Scheduler.AiService.LocalCache[schedule.QUERY_TRAIN_RESOURCES] = trainResrc + inferResrc, err := queryResource.QueryResourcesByClusterId(nil, "Inference") + if err != nil { + logx.Error(err) + } + svc.Scheduler.AiService.LocalCache[schedule.QUERY_INFERENCE_RESOURCES] = inferResrc }) //更新hpc任务状态 diff --git a/internal/logic/schedule/queryresourceslogic.go b/internal/logic/schedule/queryresourceslogic.go index e5797ef3..82fac13f 100644 --- a/internal/logic/schedule/queryresourceslogic.go +++ b/internal/logic/schedule/queryresourceslogic.go @@ -15,8 +15,9 @@ import ( ) const ( - ADAPTERID = "1777144940459986944" // 异构适配器id - QUERY_RESOURCES = "query_resources" + ADAPTERID = "1777144940459986944" // 异构适配器id + QUERY_TRAIN_RESOURCES = "train_resources" + QUERY_INFERENCE_RESOURCES = "inference_resources" ) type QueryResourcesLogic struct { @@ -41,25 +42,29 @@ func (l *QueryResourcesLogic) QueryResources(req *types.QueryResourcesReq) (resp if err != nil { return nil, err } - resources, ok := l.svcCtx.Scheduler.AiService.LocalCache[QUERY_RESOURCES] + + var resources interface{} + switch req.Type { + case "Train": + resources, _ = l.svcCtx.Scheduler.AiService.LocalCache[QUERY_TRAIN_RESOURCES] + case "Inference": + resources, _ = l.svcCtx.Scheduler.AiService.LocalCache[QUERY_INFERENCE_RESOURCES] + default: + resources, _ = l.svcCtx.Scheduler.AiService.LocalCache[QUERY_TRAIN_RESOURCES] + } + + specs, ok := resources.([]*collector.ResourceSpec) if ok { - specs, ok := resources.([]*collector.ResourceSpec) - if ok { - results := handleEmptyResourceUsage(cs.List, specs) - resp.Data = results - return resp, nil - } + results := handleEmptyResourceUsage(cs.List, specs) + resp.Data = results + return resp, nil } - rus, err := l.QueryResourcesByClusterId(cs.List) + rus, err := l.QueryResourcesByClusterId(cs.List, req.Type) if err != nil { return nil, err } - if checkCachingCondition(cs.List, rus) { - l.svcCtx.Scheduler.AiService.LocalCache[QUERY_RESOURCES] = rus - } - results := handleEmptyResourceUsage(cs.List, rus) resp.Data = results @@ -77,7 +82,7 @@ func (l *QueryResourcesLogic) QueryResources(req *types.QueryResourcesReq) (resp return nil, errors.New("no clusters found ") } - rus, err := l.QueryResourcesByClusterId(clusters) + rus, err := l.QueryResourcesByClusterId(clusters, req.Type) if err != nil { return nil, err } @@ -89,7 +94,7 @@ func (l *QueryResourcesLogic) QueryResources(req *types.QueryResourcesReq) (resp return resp, nil } -func (l *QueryResourcesLogic) QueryResourcesByClusterId(clusterinfos []types.ClusterInfo) ([]*collector.ResourceSpec, error) { +func (l *QueryResourcesLogic) QueryResourcesByClusterId(clusterinfos []types.ClusterInfo, resrcType string) ([]*collector.ResourceSpec, error) { var clusters []types.ClusterInfo if len(clusterinfos) == 0 { cs, err := l.svcCtx.Scheduler.AiStorages.GetClustersByAdapterId(ADAPTERID) @@ -121,7 +126,7 @@ func (l *QueryResourcesLogic) QueryResourcesByClusterId(clusterinfos []types.Clu return } - u, err = col.GetResourceSpecs(l.ctx) + u, err = col.GetResourceSpecs(l.ctx, resrcType) if err != nil { done <- true return diff --git a/internal/logic/schedule/schedulecreatetasklogic.go b/internal/logic/schedule/schedulecreatetasklogic.go index 19634d68..474e54b2 100644 --- a/internal/logic/schedule/schedulecreatetasklogic.go +++ b/internal/logic/schedule/schedulecreatetasklogic.go @@ -218,7 +218,7 @@ func (l *ScheduleCreateTaskLogic) getAssignedClustersByStrategy(resources *types var resCount int for i := 0; i < QUERY_RESOURCE_RETRY; i++ { defer time.Sleep(time.Second) - qResources, err := l.queryResource.QueryResourcesByClusterId(nil) + qResources, err := l.queryResource.QueryResourcesByClusterId(nil, "Train") if err != nil { continue } diff --git a/internal/scheduler/service/collector/collector.go b/internal/scheduler/service/collector/collector.go index d6172788..2e539f36 100644 --- a/internal/scheduler/service/collector/collector.go +++ b/internal/scheduler/service/collector/collector.go @@ -14,7 +14,7 @@ type AiCollector interface { UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error GetComputeCards(ctx context.Context) ([]string, error) GetUserBalance(ctx context.Context) (float64, error) - GetResourceSpecs(ctx context.Context) (*ResourceSpec, error) + GetResourceSpecs(ctx context.Context, resrcType string) (*ResourceSpec, error) } type ResourceSpec struct { diff --git a/internal/storeLink/modelarts.go b/internal/storeLink/modelarts.go index 1a599855..29aaa6a0 100644 --- a/internal/storeLink/modelarts.go +++ b/internal/storeLink/modelarts.go @@ -996,7 +996,7 @@ func (m *ModelArtsLink) CheckImageExist(ctx context.Context, option *option.Infe return errors.New("failed to find Image ") } -func (m *ModelArtsLink) GetResourceSpecs(ctx context.Context) (*collector.ResourceSpec, error) { +func (m *ModelArtsLink) GetResourceSpecs(ctx context.Context, resrcType string) (*collector.ResourceSpec, error) { var wg sync.WaitGroup //查询modelarts资源规格 req := &modelarts.GetResourceFlavorsReq{} diff --git a/internal/storeLink/octopus.go b/internal/storeLink/octopus.go index 30a108b7..e33e438e 100644 --- a/internal/storeLink/octopus.go +++ b/internal/storeLink/octopus.go @@ -1279,7 +1279,7 @@ func (o *OctopusLink) CheckModelExistence(ctx context.Context, name string, mtyp return true } -func (o *OctopusLink) GetResourceSpecs(ctx context.Context) (*collector.ResourceSpec, error) { +func (o *OctopusLink) GetResourceSpecs(ctx context.Context, resrcType string) (*collector.ResourceSpec, error) { res := &collector.ResourceSpec{ ClusterId: strconv.FormatInt(o.participantId, 10), Resources: make([]interface{}, 0), diff --git a/internal/storeLink/openi.go b/internal/storeLink/openi.go index 2d10b408..1c241d3f 100644 --- a/internal/storeLink/openi.go +++ b/internal/storeLink/openi.go @@ -770,7 +770,14 @@ func (o *OpenI) GetUserBalance(ctx context.Context) (float64, error) { return 0, errors.New("failed to implement") } -func (o *OpenI) GetResourceSpecs(ctx context.Context) (*collector.ResourceSpec, error) { +func (o *OpenI) GetResourceSpecs(ctx context.Context, resrcType string) (*collector.ResourceSpec, error) { + var jobType string + if resrcType == "Inference" { + jobType = ONLINEINFERENCE + } else if resrcType == "Train" { + jobType = TRAIN + } + var resources []interface{} res := &collector.ResourceSpec{ ClusterId: strconv.FormatInt(o.participantId, 10), @@ -795,7 +802,7 @@ func (o *OpenI) GetResourceSpecs(ctx context.Context) (*collector.ResourceSpec, param := model.TaskCreationRequiredParam{ UserName: o.userName, RepoName: TESTREPO, - JobType: TRAIN, + JobType: jobType, ComputeSource: ComputeSource[i], ClusterType: C2NET, } diff --git a/internal/storeLink/shuguangai.go b/internal/storeLink/shuguangai.go index 2fbce5f5..92ee1ae0 100644 --- a/internal/storeLink/shuguangai.go +++ b/internal/storeLink/shuguangai.go @@ -1103,7 +1103,7 @@ func (s *ShuguangAi) CheckModelExistence(ctx context.Context, name string, mtype return resp.Data.Exist } -func (s *ShuguangAi) GetResourceSpecs(ctx context.Context) (*collector.ResourceSpec, error) { +func (s *ShuguangAi) GetResourceSpecs(ctx context.Context, resrcType string) (*collector.ResourceSpec, error) { return nil, nil //var timeout = 5 //var wg sync.WaitGroup diff --git a/internal/storeLink/template.go b/internal/storeLink/template.go index f5147016..76fceb2d 100644 --- a/internal/storeLink/template.go +++ b/internal/storeLink/template.go @@ -118,6 +118,6 @@ func (o Template) GetUserBalance(ctx context.Context) (float64, error) { } // GetResourceSpecs 查询资源规格 -func (o Template) GetResourceSpecs(ctx context.Context) (*collector.ResourceSpec, error) { +func (o Template) GetResourceSpecs(ctx context.Context, resrcType string) (*collector.ResourceSpec, error) { return nil, errors.New("failed to implement") } diff --git a/internal/types/types.go b/internal/types/types.go index c9052a2d..29fde3e4 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -6013,6 +6013,7 @@ type GetClusterBalanceByIdResp struct { } type QueryResourcesReq struct { + Type string `json:"type"` ClusterIDs []string `json:"clusterIDs,optional"` }