// Package inference dispatches image-inference requests across the clusters
// assigned by the scheduler, tracks per-cluster task status in storage, and
// aggregates the per-image results.
package inference

import (
	"context"
	"encoding/json"
	"errors"
	"math/rand"
	"mime/multipart"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"
	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
)

// maxConcurrentInferRequests bounds the number of in-flight HTTP inference
// calls at any one time (was the magic buffer size 7 on the limit channel).
const maxConcurrentInferRequests = 7

// ImageFile pairs an uploaded image with the result slot it will populate.
type ImageFile struct {
	ImageResult *types.ImageResult
	File        multipart.File
}

// clusterInfo describes one cluster that is ready to serve inference:
// its resolved serving urls, its ids/name, and the number of images
// (replicas) assigned to it by the scheduling strategy.
type clusterInfo struct {
	urls        []*collector.InferUrl
	clusterId   string
	clusterName string
	imageNum    int32
}

// Infer distributes an image-inference task across the assigned clusters and
// returns the per-image results sorted by cluster name.
//
// It proceeds in phases:
//  1. concurrently resolve the inference url(s) of every cluster; clusters
//     that fail are remembered in errMap,
//  2. update task status in storage (all failed / partially failed / running),
//  3. fan the images out to the live clusters — each cluster receives a
//     contiguous batch of imageNum images — with bounded concurrency,
//  4. persist one sub-task row per image result and mark clusters completed.
//
// A non-nil error is returned only when no cluster at all can serve the
// request; partial cluster failures are recorded in storage but do not abort
// the run. The signature (including ctx in last position) is kept as-is for
// existing callers.
func Infer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile,
	aiCollectorAdapterMap map[string]map[string]collector.AiCollector, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) {
	var wg sync.WaitGroup
	clusterCh := make(chan clusterInfo, len(clusters))
	collectorMap := aiCollectorAdapterMap[opt.AdapterId]

	// Phase 1: query every cluster for its inference urls concurrently.
	// Failures are collected per cluster id so their tasks can carry the
	// error message later.
	var mu sync.Mutex
	errMap := make(map[string]string)
	for _, cluster := range clusters {
		wg.Add(1)
		c := cluster // per-iteration copy for the goroutine (pre-1.22 capture rule)
		go func() {
			defer wg.Done()
			imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt)
			if err != nil {
				mu.Lock()
				errMap[c.ClusterId] = err.Error()
				mu.Unlock()
				return
			}
			for i := range imageUrls {
				imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image"
			}
			// Best-effort name lookup; an empty name only affects display.
			clusterName, _ := storage.GetClusterNameById(c.ClusterId)
			clusterCh <- clusterInfo{
				urls:        imageUrls,
				clusterId:   c.ClusterId,
				clusterName: clusterName,
				imageNum:    c.Replicas,
			}
		}()
	}
	wg.Wait()
	close(clusterCh)

	var cs []clusterInfo
	for s := range clusterCh {
		cs = append(cs, s)
	}

	aiTaskList, err := storage.GetAiTaskListById(id)
	if err != nil {
		return nil, err
	}

	// Phase 2a: no cluster available — mark every task failed and abort.
	if len(cs) == 0 {
		for _, t := range aiTaskList {
			t.Status = constants.Failed
			t.EndTime = time.Now().Format(time.RFC3339)
			if msg, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok {
				t.Msg = msg
			}
			if err := storage.UpdateAiTask(t); err != nil {
				logx.Errorf(err.Error())
			}
		}
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
		return nil, errors.New("image infer task failed")
	}

	// Phase 2b: some clusters failed — split into failed/running and update
	// each side's task status accordingly.
	if len(clusters) != len(cs) {
		var failed, running []*strategy.AssignedCluster
		for _, cluster := range clusters {
			if contains(cs, cluster.ClusterId) {
				running = append(running, cluster)
			} else {
				failed = append(failed, cluster)
			}
		}
		for _, ac := range failed {
			for _, t := range aiTaskList {
				if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
					t.Status = constants.Failed
					t.EndTime = time.Now().Format(time.RFC3339)
					if msg, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok {
						t.Msg = msg
					}
					if err := storage.UpdateAiTask(t); err != nil {
						logx.Errorf(err.Error())
					}
				}
			}
		}
		for _, ac := range running {
			for _, t := range aiTaskList {
				if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
					t.Status = constants.Running
					if err := storage.UpdateAiTask(t); err != nil {
						logx.Errorf(err.Error())
					}
				}
			}
		}
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
	} else {
		// Phase 2c: all clusters are live.
		for _, t := range aiTaskList {
			t.Status = constants.Running
			if err := storage.UpdateAiTask(t); err != nil {
				logx.Errorf(err.Error())
			}
		}
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中")
	}

	// Phase 3: fan the images out. Each cluster gets the next contiguous
	// slice of imageNum images from ts; requests share one bounded
	// concurrency semaphore across all clusters.
	resultCh := make(chan *types.ImageResult, len(ts))
	limit := make(chan bool, maxConcurrentInferRequests)
	var start, end int32
	for _, c := range cs {
		end += c.imageNum
		batch := ts[start:end]
		start += c.imageNum
		wg.Add(len(batch))
		go sendInferReq(batch, c, &wg, resultCh, limit)
	}
	wg.Wait()
	close(resultCh)

	var results []*types.ImageResult
	for r := range resultCh {
		results = append(results, r)
	}
	sort.Slice(results, func(p, q int) bool {
		return results[p].ClusterName < results[q].ClusterName
	})

	// Phase 4a: persist one sub-task row per image result. A failed save is
	// logged instead of panicking (was: panic(err), which crashed the
	// whole service mid-task).
	for _, r := range results {
		for _, task := range aiTaskList {
			if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
				sub := models.TaskAiSub{
					TaskId:      id,
					TaskName:    task.Name,
					TaskAiId:    task.TaskId,
					TaskAiName:  task.Name,
					ImageName:   r.ImageName,
					Result:      r.ImageResult,
					Card:        r.Card,
					ClusterId:   task.ClusterId,
					ClusterName: r.ClusterName,
				}
				if err := storage.SaveAiTaskImageSubTask(&sub); err != nil {
					logx.Errorf(err.Error())
				}
			}
		}
	}

	// Phase 4b: mark the served clusters completed and send the final notice.
	var successStatusCount int
	for _, c := range cs {
		for _, t := range aiTaskList {
			if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
				t.Status = constants.Completed
				t.EndTime = time.Now().Format(time.RFC3339)
				if err := storage.UpdateAiTask(t); err != nil {
					logx.Errorf(err.Error())
				}
				successStatusCount++
			}
		}
	}
	if len(cs) == successStatusCount {
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
	} else {
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
	}
	return results, nil
}

// sendInferReq fires one inference request per image against the given
// cluster. Concurrency is bounded by the shared limit semaphore: a slot is
// acquired before each goroutine starts and released exactly once when it
// finishes.
//
// BUG FIXED: the original released the semaphore twice per acquisition (once
// in the loop body and once inside the goroutine), so the limit never
// actually bounded concurrency and trailing goroutines leaked, blocked
// forever on `<-limit`.
//
// Every image yields exactly one value on ch; on request failure the error
// message is stored in place of the result so the caller still receives one
// entry per image. The caller must have added len(images) to wg beforehand.
func sendInferReq(images []*ImageFile, cluster clusterInfo, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) {
	for _, image := range images {
		limit <- true // acquire one concurrency slot
		go func(t *ImageFile, c clusterInfo) {
			defer wg.Done()
			defer func() { <-limit }() // release the slot exactly once

			// One replica is used directly; multiple replicas are
			// load-balanced by picking one at random.
			idx := 0
			if len(c.urls) > 1 {
				idx = rand.Intn(len(c.urls))
			}
			r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName)
			if err != nil {
				r = err.Error()
			}
			t.ImageResult.ImageResult = r
			t.ImageResult.ClusterId = c.clusterId
			t.ImageResult.ClusterName = c.clusterName
			t.ImageResult.Card = c.urls[idx].Card
			ch <- t.ImageResult
		}(image, cluster)
	}
}

// getInferResult posts one image file to the given inference url and returns
// the textual result. The ModelArts cluster takes a dedicated request path.
func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) {
	// NOTE(review): routing on the display name is fragile — presumably this
	// should key on a cluster type/id instead; confirm with callers.
	if clusterName == "鹏城云脑II-modelarts" {
		return getInferResultModelarts(url, file, fileName)
	}
	var res Res
	req := GetRestyRequest(20)
	if _, err := req.
		SetFileReader("file", fileName, file).
		SetResult(&res).
		Post(url); err != nil {
		return "", err
	}
	return res.Result, nil
}

// getInferResultModelarts posts one image to a ModelArts inference endpoint
// via the shared utils helper and decodes the JSON response body.
//
// BUG FIXED: a JSON parse failure used to call log.Fatalf, killing the whole
// process; it now just fails this single request. (A commented-out block
// containing hardcoded AK/SK credentials was also removed.)
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
	body, err := utils.SendRequest("POST", url, file, fileName)
	if err != nil {
		return "", err
	}
	var res Res
	if err := json.Unmarshal([]byte(body), &res); err != nil {
		return "", err
	}
	return res.Result, nil
}

// GetRestyRequest returns a fresh resty request whose underlying client times
// out after the given number of seconds.
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
	client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
	return client.R()
}

// Res is the JSON payload returned by the inference endpoints.
type Res struct {
	Result string `json:"result"`
}

// contains reports whether any prepared cluster in cs has the given
// cluster id.
func contains(cs []clusterInfo, e string) bool {
	for _, c := range cs {
		if c.clusterId == e {
			return true
		}
	}
	return false
}