package inference

import (
	"context"
	"errors"
	"math/rand"
	"mime/multipart"
	"net/http"
	"sync"

	"github.com/go-resty/resty/v2"
	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
)

type ImageInferenceLogic struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ImageInferenceLogic {
	return &ImageInferenceLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
}

// imageTask pairs an uploaded image with the result entry that is filled in
// once inference for that image completes.
type imageTask struct {
	imageResult *types.ImageResult
	file        multipart.File
}

// clusterTarget describes one scheduled cluster: its inference urls, its
// display name, and the number of images assigned to it.
type clusterTarget struct {
	urls        []*collector.ImageInferUrl
	clusterName string
	imageNum    int32
}

// ImageInference is kept to satisfy the generated handler interface; the
// multipart-aware entry point is ImageInfer below.
func (l *ImageInferenceLogic) ImageInference(req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
	return nil, nil
}

// ImageInfer reads the uploaded images from the multipart form, schedules them
// across the clusters of the chosen adapter, and returns one result per image.
func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
	resp = &types.ImageInferenceResp{}
	opt := &option.InferOption{
		TaskName:        req.TaskName,
		TaskDesc:        req.TaskDesc,
		AdapterId:       req.AdapterId,
		AiClusterIds:    req.AiClusterIds,
		ModelName:       req.ModelName,
		ModelType:       req.ModelType,
		Strategy:        req.Strategy,
		StaticWeightMap: req.StaticWeightMap,
	}

	if r.MultipartForm == nil || len(r.MultipartForm.File["images"]) == 0 {
		return nil, errors.New("images do not exist")
	}

	var ts []imageTask
	for _, header := range r.MultipartForm.File["images"] {
		file, err := header.Open()
		if err != nil {
			return nil, err
		}
		// The files are streamed to the clusters inside infer, which runs
		// before this function returns, so closing on return is safe.
		defer file.Close()

		ts = append(ts, imageTask{
			imageResult: &types.ImageResult{ImageName: header.Filename},
			file:        file,
		})
	}

	if _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]; !ok {
		return nil, errors.New("adapterId does not exist")
	}

	var strat strategy.Strategy
	switch opt.Strategy {
	case strategy.STATIC_WEIGHT:
		//todo resources should match cluster StaticWeightMap
		strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts)))
	default:
		return nil, errors.New("no strategy has been chosen")
	}

	clusters, err := strat.Schedule()
	if err != nil {
		return nil, err
	}

	results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx)
	if err != nil {
		return nil, err
	}
	resp.InferResults = results
	return resp, nil
}

// infer resolves the inference urls of every assigned cluster, splits the
// images between the clusters according to their replica counts, and collects
// the per-image results.
func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []imageTask, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) {
	if len(clusters) == 0 {
		return nil, errors.New("no clusters assigned")
	}
	// Drop clusters that were assigned no images.
	for i := len(clusters) - 1; i >= 0; i-- {
		if clusters[i].Replicas == 0 {
			clusters = append(clusters[:i], clusters[i+1:]...)
		}
	}

	// Phase 1: ask every cluster's collector for its inference urls in parallel.
	var wg sync.WaitGroup
	clusterCh := make(chan clusterTarget, len(clusters))
	collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
	for _, cluster := range clusters {
		wg.Add(1)
		c := cluster
		go func() {
			defer wg.Done()
			imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt)
			if err != nil {
				logx.WithContext(ctx).Errorf("failed to get inference urls for cluster %v: %v", c.ClusterId, err)
				return
			}
			clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
			clusterCh <- clusterTarget{
				urls:        imageUrls,
				clusterName: clusterName,
				imageNum:    c.Replicas,
			}
		}()
	}
	wg.Wait()
	close(clusterCh)

	var cs []clusterTarget
	for s := range clusterCh {
		cs = append(cs, s)
	}

	// Phase 2: split the images between the reachable clusters according to
	// their replica counts and send the inference requests in parallel.
	resultCh := make(chan *types.ImageResult, len(ts))
	var imageNumIdx int32
	for _, c := range cs {
		batch := ts[imageNumIdx : imageNumIdx+c.imageNum]
		imageNumIdx += c.imageNum
		wg.Add(len(batch))
		go sendInferReq(batch, c, &wg, resultCh)
	}
	wg.Wait()
	close(resultCh)

	var results []*types.ImageResult
	for s := range resultCh {
		results = append(results, s)
	}
	return results, nil
}

// sendInferReq posts every image in the batch to one of the cluster's
// inference urls and pushes the filled-in result onto ch.
func sendInferReq(images []imageTask, cluster clusterTarget, wg *sync.WaitGroup, ch chan<- *types.ImageResult) {
	for _, image := range images {
		go func(t imageTask, c clusterTarget) {
			defer wg.Done()
			if len(c.urls) == 0 {
				t.imageResult.ImageResult = "no inference url available"
				ch <- t.imageResult
				return
			}
			// Use the only url, or a random one when several cards serve the model.
			idx := 0
			if len(c.urls) > 1 {
				idx = rand.Intn(len(c.urls))
			}
			r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName)
			if err != nil {
				t.imageResult.ImageResult = err.Error()
				ch <- t.imageResult
				return
			}
			t.imageResult.ImageResult = r
			t.imageResult.ClusterName = c.clusterName
			t.imageResult.Card = c.urls[idx].Card
			ch <- t.imageResult
		}(image, cluster)
	}
}

// getInferResult uploads a single image to url and returns the model's result string.
func getInferResult(url string, file multipart.File, fileName string) (string, error) {
	var res Res
	httpResp, err := GetACHttpRequest().
		SetFileReader("file", fileName, file).
		SetResult(&res).
		Post(url)
	if err != nil {
		return "", err
	}
	if httpResp.IsError() {
		return "", errors.New("inference request failed: " + httpResp.Status())
	}
	return res.Result, nil
}

// GetACHttpRequest builds a fresh resty client and returns a request bound to it.
func GetACHttpRequest() *resty.Request {
	return resty.New().R()
}

// Res is the response body returned by the cluster inference endpoint.
type Res struct {
	Result string `json:"result"`
}