package inference import ( "context" "errors" "github.com/go-resty/resty/v2" "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "math/rand" "mime/multipart" "net/http" "sort" "strconv" "sync" "time" ) type ImageInferenceLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ImageInferenceLogic { return &ImageInferenceLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx, } } func (l *ImageInferenceLogic) ImageInference(req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { return nil, nil } func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { resp = &types.ImageInferenceResp{} opt := &option.InferOption{ TaskName: req.TaskName, TaskDesc: req.TaskDesc, AdapterId: req.AdapterId, AiClusterIds: req.AiClusterIds, ModelName: req.ModelName, ModelType: req.ModelType, Strategy: req.Strategy, StaticWeightMap: req.StaticWeightMap, } var ts []struct { imageResult *types.ImageResult file multipart.File } uploadedFiles := r.MultipartForm.File if len(uploadedFiles) == 0 { return nil, errors.New("Images does not exist") } if len(uploadedFiles["images"]) == 0 { return nil, errors.New("Images does not exist") } for _, header := range uploadedFiles["images"] { file, err := header.Open() if err != nil { return nil, err } defer file.Close() var ir types.ImageResult ir.ImageName = header.Filename t := struct { imageResult *types.ImageResult file multipart.File }{ imageResult: &ir, file: file, } ts = append(ts, t) } _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] if !ok { return nil, errors.New("AdapterId does not exist") } var strat strategy.Strategy switch opt.Strategy { case strategy.STATIC_WEIGHT: strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) if err != nil { return nil, err } default: return nil, errors.New("no strategy has been chosen") } clusters, err := strat.Schedule() if err != nil { return nil, err } results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx) if err != nil { return nil, err } resp.InferResults = results return resp, nil } var acs []*strategy.AssignedCluster var aiTaskList []*models.TaskAi func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct { imageResult *types.ImageResult file multipart.File }, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) { if clusters == nil || len(clusters) == 0 { return nil, errors.New("clusters is nil") } for i := len(clusters) - 1; i >= 0; i-- { if clusters[i].Replicas == 0 { clusters = append(clusters[:i], clusters[i+1:]...) } } var wg sync.WaitGroup var cluster_ch = make(chan struct { urls []*collector.ImageInferUrl clusterId string clusterName string imageNum int32 }, len(clusters)) var cs []struct { urls []*collector.ImageInferUrl clusterId string clusterName string imageNum int32 } collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] //save task var synergystatus int64 if len(clusters) > 1 { synergystatus = 1 } strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) if err != nil { return nil, err } adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) if err != nil { return nil, err } id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") if err != nil { return nil, err } svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") //save taskai for _, c := range clusters { clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) opt.Replica = c.Replicas err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") if err != nil { return nil, err } } for _, cluster := range clusters { wg.Add(1) c := cluster go func() { imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) if err != nil { wg.Done() return } clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) s := struct { urls []*collector.ImageInferUrl clusterId string clusterName string imageNum int32 }{ urls: imageUrls, clusterId: c.ClusterId, clusterName: clusterName, imageNum: c.Replicas, } cluster_ch <- s wg.Done() return }() } wg.Wait() close(cluster_ch) for s := range cluster_ch { cs = append(cs, s) } tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) if tx.Error != nil { return nil, tx.Error } //change cluster status if len(clusters) != len(cs) { for _, cluster := range clusters { if contains(cs, cluster.ClusterId) { continue } else { var ac *strategy.AssignedCluster ac = cluster acs = append(acs, ac) } } // update failed cluster status for _, ac := range acs { for _, t := range aiTaskList { if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { t.Status = constants.Failed err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) if err != nil { logx.Errorf(tx.Error.Error()) } } } } } var result_ch = make(chan *types.ImageResult, len(ts)) var results []*types.ImageResult var imageNumIdx int32 = 0 var imageNumIdxEnd int32 = 0 for _, c := range cs { new_images := make([]struct { imageResult *types.ImageResult file multipart.File }, len(ts)) copy(new_images, ts) imageNumIdxEnd = imageNumIdxEnd + c.imageNum new_images = new_images[imageNumIdx:imageNumIdxEnd] imageNumIdx = imageNumIdx + c.imageNum wg.Add(len(new_images)) go sendInferReq(new_images, c, &wg, *svcCtx, result_ch) } wg.Wait() close(result_ch) for s := range result_ch { results = append(results, s) } sort.Slice(results, func(p, q int) bool { return results[p].ClusterName < results[q].ClusterName }) // update succeeded cluster status for _, c := range cs { for _, t := range aiTaskList { if c.clusterId == strconv.Itoa(int(t.ClusterId)) { t.Status = constants.Completed err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) if err != nil { logx.Errorf(tx.Error.Error()) } } } } return results, nil } func sendInferReq(images []struct { imageResult *types.ImageResult file multipart.File }, cluster struct { urls []*collector.ImageInferUrl clusterId string clusterName string imageNum int32 }, wg *sync.WaitGroup, svcCtx svc.ServiceContext, ch chan<- *types.ImageResult) { for _, image := range images { go func(t struct { imageResult *types.ImageResult file multipart.File }, c struct { urls []*collector.ImageInferUrl clusterId string clusterName string imageNum int32 }) { if len(c.urls) == 1 { r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) if err != nil { t.imageResult.ImageResult = err.Error() t.imageResult.ClusterName = c.clusterName t.imageResult.Card = c.urls[0].Card ch <- t.imageResult wg.Done() return } t.imageResult.ImageResult = r t.imageResult.ClusterName = c.clusterName t.imageResult.Card = c.urls[0].Card ch <- t.imageResult wg.Done() return } else { idx := rand.Intn(len(c.urls)) r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) if err != nil { t.imageResult.ImageResult = err.Error() t.imageResult.ClusterName = c.clusterName t.imageResult.Card = c.urls[idx].Card ch <- t.imageResult wg.Done() return } t.imageResult.ImageResult = r t.imageResult.ClusterName = c.clusterName t.imageResult.Card = c.urls[idx].Card for _, ac := range acs { for _, task := range aiTaskList { if ac.ClusterId == strconv.Itoa(int(task.ClusterId)) && ac.ClusterId == t.imageResult.ClusterId { taskAiSub := &models.TaskAiSub{ Id: task.Id, ImageName: t.imageResult.ImageName, Result: t.imageResult.ImageResult, Card: t.imageResult.Card, ClusterId: task.ClusterId, ClusterName: t.imageResult.ClusterName, } tx := svcCtx.DbEngin.Save(&taskAiSub) if tx.Error != nil { logx.Errorf(err.Error()) } } continue } continue } ch <- t.imageResult wg.Done() return } }(image, cluster) } } func getInferResult(url string, file multipart.File, fileName string) (string, error) { var res Res req := GetRestyRequest(10) _, err := req. SetFileReader("file", fileName, file). SetResult(&res). Post(url) if err != nil { return "", err } return res.Result, nil } func GetRestyRequest(timeoutSeconds int64) *resty.Request { client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) request := client.R() return request } type Res struct { Result string `json:"result"` } func contains(cs []struct { urls []*collector.ImageInferUrl clusterId string clusterName string imageNum int32 }, e string) bool { for _, c := range cs { if c.clusterId == e { return true } } return false }