| @@ -5,11 +5,11 @@ import ( | |||
| "errors" | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/imageInference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||
| "net/http" | |||
| ) | |||
| @@ -102,44 +102,31 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere | |||
| } | |||
| } | |||
| //save task | |||
| var synergystatus int64 | |||
| if len(clusters) > 1 { | |||
| synergystatus = 1 | |||
| } | |||
| strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") | |||
| imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") | |||
| in := inference.Inference{ | |||
| In: imageInfer, | |||
| } | |||
| //save taskai | |||
| for _, c := range clusters { | |||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) | |||
| opt.Replica = c.Replicas | |||
| err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| id, err := in.In.CreateTask() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| go func() { | |||
| ic, err := imageInference.NewImageClassification(ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, id, adapterName) | |||
| err := in.In.InferTask(id) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| return | |||
| } | |||
| ic.Classify() | |||
| }() | |||
| return resp, nil | |||
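// Net flow of the new ImageInfer handler in this hunk (condensed sketch, not verbatim from the
// patch; error handling abbreviated; opt, ts, clusters, resp and l.svcCtx come from the
// surrounding handler):
//
//	imageInfer, _ := imageInference.New(imageInference.NewImageClassification(), ts, clusters, opt,
//		l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName)
//	in := inference.Inference{In: imageInfer}
//	id, _ := in.In.CreateTask() // persists the task and per-cluster ai tasks, emits the "create" notice
//	go func() {
//		if err := in.In.InferTask(id); err != nil { // filter clusters, infer, save sub-tasks, update status
//			logx.Errorf(err.Error())
//		}
//	}()
//	return resp, nil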
| @@ -6,14 +6,9 @@ import ( | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/textInference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||
| "strconv" | |||
| "sync" | |||
| "time" | |||
| ) | |||
| type TextToTextInferenceLogic struct { | |||
| @@ -46,105 +41,29 @@ func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInfe | |||
| return nil, errors.New("AdapterId does not exist") | |||
| } | |||
| //save task | |||
| var synergystatus int64 | |||
| var strategyCode int64 | |||
| adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) | |||
| inType, err := textInference.NewTextToText(opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12") | |||
| textInfer, err := textInference.New(inType, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| var wg sync.WaitGroup | |||
| var cluster_ch = make(chan struct { | |||
| urls []*inference.InferUrl | |||
| clusterId string | |||
| clusterName string | |||
| }, len(opt.AiClusterIds)) | |||
| var cs []struct { | |||
| urls []*inference.InferUrl | |||
| clusterId string | |||
| clusterName string | |||
| } | |||
| inferMap := l.svcCtx.Scheduler.AiService.InferenceAdapterMap[opt.AdapterId] | |||
| //save taskai | |||
| for _, clusterId := range opt.AiClusterIds { | |||
| wg.Add(1) | |||
| go func(cId string) { | |||
| urls, err := inferMap[cId].GetInferUrl(l.ctx, opt) | |||
| if err != nil { | |||
| wg.Done() | |||
| return | |||
| } | |||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId) | |||
| s := struct { | |||
| urls []*inference.InferUrl | |||
| clusterId string | |||
| clusterName string | |||
| }{ | |||
| urls: urls, | |||
| clusterId: cId, | |||
| clusterName: clusterName, | |||
| } | |||
| cluster_ch <- s | |||
| wg.Done() | |||
| return | |||
| }(clusterId) | |||
| } | |||
| wg.Wait() | |||
| close(cluster_ch) | |||
| for s := range cluster_ch { | |||
| cs = append(cs, s) | |||
| } | |||
| if len(cs) == 0 { | |||
| clusterId := opt.AiClusterIds[0] | |||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(opt.AiClusterIds[0]) | |||
| err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, clusterId, clusterName, "", constants.Failed, "") | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||
| } | |||
| for _, c := range cs { | |||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId) | |||
| err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| in := inference.Inference{ | |||
| In: textInfer, | |||
| } | |||
| var aiTaskList []*models.TaskAi | |||
| tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) | |||
| if tx.Error != nil { | |||
| return nil, tx.Error | |||
| id, err := in.In.CreateTask() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| for i, t := range aiTaskList { | |||
| if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId { | |||
| t.Status = constants.Completed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat" | |||
| t.InferUrl = url | |||
| err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(tx.Error.Error()) | |||
| } | |||
| } | |||
| err = in.In.InferTask(id) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") | |||
| return resp, nil | |||
| } | |||
| @@ -26,7 +26,7 @@ const ( | |||
| type AiService struct { | |||
| AiExecutorAdapterMap map[string]map[string]executor.AiExecutor | |||
| AiCollectorAdapterMap map[string]map[string]collector.AiCollector | |||
| InferenceAdapterMap map[string]map[string]inference.Inference | |||
| InferenceAdapterMap map[string]map[string]inference.ICluster | |||
| Storage *database.AiStorage | |||
| mu sync.Mutex | |||
| } | |||
| @@ -40,7 +40,7 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService | |||
| aiService := &AiService{ | |||
| AiExecutorAdapterMap: make(map[string]map[string]executor.AiExecutor), | |||
| AiCollectorAdapterMap: make(map[string]map[string]collector.AiCollector), | |||
| InferenceAdapterMap: make(map[string]map[string]inference.Inference), | |||
| InferenceAdapterMap: make(map[string]map[string]inference.ICluster), | |||
| Storage: storages, | |||
| } | |||
| for _, id := range adapterIds { | |||
| @@ -60,10 +60,10 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService | |||
| return aiService, nil | |||
| } | |||
| func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector, map[string]inference.Inference) { | |||
| func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector, map[string]inference.ICluster) { | |||
| executorMap := make(map[string]executor.AiExecutor) | |||
| collectorMap := make(map[string]collector.AiCollector) | |||
| inferenceMap := make(map[string]inference.Inference) | |||
| inferenceMap := make(map[string]inference.ICluster) | |||
| for _, c := range clusters { | |||
| switch c.Name { | |||
| case OCTOPUS: | |||
| @@ -1,419 +1,26 @@ | |||
| package imageInference | |||
| import ( | |||
| "encoding/json" | |||
| "errors" | |||
| "github.com/go-resty/resty/v2" | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" | |||
| "log" | |||
| "math/rand" | |||
| "mime/multipart" | |||
| "net/http" | |||
| "sort" | |||
| "strconv" | |||
| "sync" | |||
| "time" | |||
| ) | |||
| import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| const ( | |||
| IMAGE = "image" | |||
| FORWARD_SLASH = "/" | |||
| CLASSIFICATION = "image" | |||
| CLASSIFICATION_AiTYPE = "11" | |||
| ) | |||
| type ImageClassificationInterface interface { | |||
| Classify() ([]*types.ImageResult, error) | |||
| } | |||
| type ImageFile struct { | |||
| ImageResult *types.ImageResult | |||
| File multipart.File | |||
| } | |||
| type FilteredCluster struct { | |||
| urls []*inference.InferUrl | |||
| clusterId string | |||
| clusterName string | |||
| imageNum int32 | |||
| } | |||
| type ImageClassification struct { | |||
| files []*ImageFile | |||
| clusters []*strategy.AssignedCluster | |||
| opt *option.InferOption | |||
| storage *database.AiStorage | |||
| inferAdapter map[string]map[string]inference.Inference | |||
| errMap map[string]string | |||
| taskId int64 | |||
| adapterName string | |||
| aiTaskList []*models.TaskAi | |||
| } | |||
| func NewImageClassification(files []*ImageFile, | |||
| clusters []*strategy.AssignedCluster, | |||
| opt *option.InferOption, | |||
| storage *database.AiStorage, | |||
| inferAdapter map[string]map[string]inference.Inference, | |||
| taskId int64, | |||
| adapterName string) (*ImageClassification, error) { | |||
| aiTaskList, err := storage.GetAiTaskListById(taskId) | |||
| if err != nil || len(aiTaskList) == 0 { | |||
| return nil, err | |||
| } | |||
| return &ImageClassification{ | |||
| files: files, | |||
| clusters: clusters, | |||
| opt: opt, | |||
| storage: storage, | |||
| inferAdapter: inferAdapter, | |||
| taskId: taskId, | |||
| adapterName: adapterName, | |||
| errMap: make(map[string]string), | |||
| aiTaskList: aiTaskList, | |||
| }, nil | |||
| } | |||
| func (i *ImageClassification) Classify() ([]*types.ImageResult, error) { | |||
| clusters, err := i.filterClusters() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| err = i.updateStatus(clusters) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| results, err := i.inferImages(clusters) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return results, nil | |||
| } | |||
| func (i *ImageClassification) filterClusters() ([]*FilteredCluster, error) { | |||
| var wg sync.WaitGroup | |||
| var ch = make(chan *FilteredCluster, len(i.clusters)) | |||
| var cs []*FilteredCluster | |||
| var mutex sync.Mutex | |||
| inferMap := i.inferAdapter[i.opt.AdapterId] | |||
| for _, cluster := range i.clusters { | |||
| wg.Add(1) | |||
| c := cluster | |||
| go func() { | |||
| r := http.Request{} | |||
| imageUrls, err := inferMap[c.ClusterId].GetInferUrl(r.Context(), i.opt) | |||
| if err != nil { | |||
| mutex.Lock() | |||
| i.errMap[c.ClusterId] = err.Error() | |||
| mutex.Unlock() | |||
| wg.Done() | |||
| return | |||
| } | |||
| for i, _ := range imageUrls { | |||
| imageUrls[i].Url = imageUrls[i].Url + FORWARD_SLASH + IMAGE | |||
| } | |||
| clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) | |||
| var f FilteredCluster | |||
| f.urls = imageUrls | |||
| f.clusterId = c.ClusterId | |||
| f.clusterName = clusterName | |||
| f.imageNum = c.Replicas | |||
| ch <- &f | |||
| wg.Done() | |||
| return | |||
| }() | |||
| } | |||
| wg.Wait() | |||
| close(ch) | |||
| for s := range ch { | |||
| cs = append(cs, s) | |||
| } | |||
| return cs, nil | |||
| } | |||
| func (i *ImageClassification) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { | |||
| var wg sync.WaitGroup | |||
| var ch = make(chan *types.ImageResult, len(i.files)) | |||
| var results []*types.ImageResult | |||
| limit := make(chan bool, 7) | |||
| var imageNumIdx int32 = 0 | |||
| var imageNumIdxEnd int32 = 0 | |||
| for _, c := range cs { | |||
| new_images := make([]*ImageFile, len(i.files)) | |||
| copy(new_images, i.files) | |||
| imageNumIdxEnd = imageNumIdxEnd + c.imageNum | |||
| new_images = new_images[imageNumIdx:imageNumIdxEnd] | |||
| imageNumIdx = imageNumIdx + c.imageNum | |||
| wg.Add(len(new_images)) | |||
| go sendInferReq(new_images, c, &wg, ch, limit) | |||
| } | |||
| wg.Wait() | |||
| close(ch) | |||
| for s := range ch { | |||
| results = append(results, s) | |||
| } | |||
| sort.Slice(results, func(p, q int) bool { | |||
| return results[p].ClusterName < results[q].ClusterName | |||
| }) | |||
| //save ai sub tasks | |||
| for _, r := range results { | |||
| for _, task := range i.aiTaskList { | |||
| if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { | |||
| taskAiSub := models.TaskAiSub{ | |||
| TaskId: i.taskId, | |||
| TaskName: task.Name, | |||
| TaskAiId: task.TaskId, | |||
| TaskAiName: task.Name, | |||
| ImageName: r.ImageName, | |||
| Result: r.ImageResult, | |||
| Card: r.Card, | |||
| ClusterId: task.ClusterId, | |||
| ClusterName: r.ClusterName, | |||
| } | |||
| err := i.storage.SaveAiTaskImageSubTask(&taskAiSub) | |||
| if err != nil { | |||
| panic(err) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // update succeeded cluster status | |||
| var successStatusCount int | |||
| for _, c := range cs { | |||
| for _, t := range i.aiTaskList { | |||
| if c.clusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| t.Status = constants.Completed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| successStatusCount++ | |||
| } else { | |||
| continue | |||
| } | |||
| } | |||
| } | |||
| if len(cs) == successStatusCount { | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") | |||
| } else { | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||
| } | |||
| return results, nil | |||
| func NewImageClassification() *ImageClassification { | |||
| return &ImageClassification{} | |||
| } | |||
| func (i *ImageClassification) updateStatus(cs []*FilteredCluster) error { | |||
| //no cluster available | |||
| if len(cs) == 0 { | |||
| for _, t := range i.aiTaskList { | |||
| t.Status = constants.Failed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||
| } | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||
| return errors.New("image infer task failed") | |||
| } | |||
| //change cluster status | |||
| if len(i.clusters) != len(cs) { | |||
| var acs []*strategy.AssignedCluster | |||
| var rcs []*strategy.AssignedCluster | |||
| for _, cluster := range i.clusters { | |||
| if contains(cs, cluster.ClusterId) { | |||
| var ac *strategy.AssignedCluster | |||
| ac = cluster | |||
| rcs = append(rcs, ac) | |||
| } else { | |||
| var ac *strategy.AssignedCluster | |||
| ac = cluster | |||
| acs = append(acs, ac) | |||
| } | |||
| } | |||
| // update failed cluster status | |||
| for _, ac := range acs { | |||
| for _, t := range i.aiTaskList { | |||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| t.Status = constants.Failed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||
| } | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // update running cluster status | |||
| for _, ac := range rcs { | |||
| for _, t := range i.aiTaskList { | |||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| t.Status = constants.Running | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||
| } else { | |||
| for _, t := range i.aiTaskList { | |||
| t.Status = constants.Running | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") | |||
| func (ic *ImageClassification) AppendRoute(urls []*inference.InferUrl) error { | |||
for i := range urls {
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CLASSIFICATION | |||
| } | |||
| return nil | |||
| } | |||
| func sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { | |||
| for _, image := range images { | |||
| limit <- true | |||
| go func(t *ImageFile, c *FilteredCluster) { | |||
| if len(c.urls) == 1 { | |||
| r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||
| if err != nil { | |||
| t.ImageResult.ImageResult = err.Error() | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[0].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } | |||
| t.ImageResult.ImageResult = r | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[0].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } else { | |||
| idx := rand.Intn(len(c.urls)) | |||
| r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||
| if err != nil { | |||
| t.ImageResult.ImageResult = err.Error() | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[idx].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } | |||
| t.ImageResult.ImageResult = r | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[idx].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } | |||
| }(image, cluster) | |||
| <-limit | |||
| } | |||
| } | |||
| func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { | |||
| if clusterName == "鹏城云脑II-modelarts" { | |||
| r, err := getInferResultModelarts(url, file, fileName) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
| return r, nil | |||
| } | |||
| var res Res | |||
| req := GetRestyRequest(20) | |||
| _, err := req. | |||
| SetFileReader("file", fileName, file). | |||
| SetResult(&res). | |||
| Post(url) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
| return res.Result, nil | |||
| } | |||
| func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { | |||
| var res Res | |||
| /* req := GetRestyRequest(20) | |||
| _, err := req. | |||
| SetFileReader("file", fileName, file). | |||
| SetHeaders(map[string]string{ | |||
| "ak": "UNEHPHO4Z7YSNPKRXFE4", | |||
| "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", | |||
| }). | |||
| SetResult(&res). | |||
| Post(url) | |||
| if err != nil { | |||
| return "", err | |||
| }*/ | |||
| body, err := utils.SendRequest("POST", url, file, fileName) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
| errjson := json.Unmarshal([]byte(body), &res) | |||
| if errjson != nil { | |||
| log.Fatalf("Error parsing JSON: %s", errjson) | |||
| } | |||
| return res.Result, nil | |||
| } | |||
| func GetRestyRequest(timeoutSeconds int64) *resty.Request { | |||
| client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) | |||
| request := client.R() | |||
| return request | |||
| } | |||
| type Res struct { | |||
| Result string `json:"result"` | |||
| } | |||
| func contains(cs []*FilteredCluster, e string) bool { | |||
| for _, c := range cs { | |||
| if c.clusterId == e { | |||
| return true | |||
| } | |||
| } | |||
| return false | |||
| func (ic *ImageClassification) GetAiType() string { | |||
| return CLASSIFICATION_AiTYPE | |||
| } | |||
| @@ -1,9 +1,469 @@ | |||
| package imageInference | |||
| import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| import ( | |||
| "encoding/json" | |||
| "errors" | |||
| "github.com/go-resty/resty/v2" | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" | |||
| "log" | |||
| "math/rand" | |||
| "mime/multipart" | |||
| "net/http" | |||
| "sort" | |||
| "strconv" | |||
| "sync" | |||
| "time" | |||
| ) | |||
| type ImageInference interface { | |||
| filterClusters() ([]*FilteredCluster, error) | |||
| inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) | |||
| updateStatus(cs []*FilteredCluster) error | |||
| type IImageInference interface { | |||
| AppendRoute(urls []*inference.InferUrl) error | |||
| GetAiType() string | |||
| } | |||
| type ImageFile struct { | |||
| ImageResult *types.ImageResult | |||
| File multipart.File | |||
| } | |||
| type FilteredCluster struct { | |||
| urls []*inference.InferUrl | |||
| clusterId string | |||
| clusterName string | |||
| imageNum int32 | |||
| } | |||
| type ImageInference struct { | |||
| inference IImageInference | |||
| files []*ImageFile | |||
| clusters []*strategy.AssignedCluster | |||
| opt *option.InferOption | |||
| storage *database.AiStorage | |||
| inferAdapter map[string]map[string]inference.ICluster | |||
| errMap map[string]string | |||
| adapterName string | |||
| } | |||
| func New( | |||
| inference IImageInference, | |||
| files []*ImageFile, | |||
| clusters []*strategy.AssignedCluster, | |||
| opt *option.InferOption, | |||
| storage *database.AiStorage, | |||
| inferAdapter map[string]map[string]inference.ICluster, | |||
| adapterName string) (*ImageInference, error) { | |||
| return &ImageInference{ | |||
| inference: inference, | |||
| files: files, | |||
| clusters: clusters, | |||
| opt: opt, | |||
| storage: storage, | |||
| inferAdapter: inferAdapter, | |||
| adapterName: adapterName, | |||
| errMap: make(map[string]string), | |||
| }, nil | |||
| } | |||
| func (i *ImageInference) CreateTask() (int64, error) { | |||
| id, err := i.saveTask() | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| err = i.saveAiTask(id) | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| return id, nil | |||
| } | |||
| func (i *ImageInference) InferTask(id int64) error { | |||
| clusters, err := i.filterClusters() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| aiTaskList, err := i.storage.GetAiTaskListById(id) | |||
| if err != nil || len(aiTaskList) == 0 { | |||
| return err | |||
| } | |||
| err = i.updateStatus(aiTaskList, clusters) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| results, err := i.inferImages(clusters) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| err = i.saveAiSubTasks(id, aiTaskList, clusters, results) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (i *ImageInference) saveTask() (int64, error) { | |||
| var synergystatus int64 | |||
| if len(i.clusters) > 1 { | |||
| synergystatus = 1 | |||
| } | |||
| strategyCode, err := i.storage.GetStrategyCode(i.opt.Strategy) | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| id, err := i.storage.SaveTask(i.opt.TaskName, strategyCode, synergystatus, i.inference.GetAiType()) | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "create", "任务创建中") | |||
| return id, nil | |||
| } | |||
| func (i *ImageInference) saveAiTask(id int64) error { | |||
| for _, c := range i.clusters { | |||
| clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) | |||
| i.opt.Replica = c.Replicas | |||
| err := i.storage.SaveAiTask(id, i.opt, i.adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| return nil | |||
| } | |||
| func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { | |||
| var wg sync.WaitGroup | |||
| var ch = make(chan *FilteredCluster, len(i.clusters)) | |||
| var cs []*FilteredCluster | |||
| var mutex sync.Mutex | |||
| inferMap := i.inferAdapter[i.opt.AdapterId] | |||
| for _, cluster := range i.clusters { | |||
| wg.Add(1) | |||
| c := cluster | |||
| go func() { | |||
| r := http.Request{} | |||
| imageUrls, err := inferMap[c.ClusterId].GetInferUrl(r.Context(), i.opt) | |||
| if err != nil { | |||
| mutex.Lock() | |||
| i.errMap[c.ClusterId] = err.Error() | |||
| mutex.Unlock() | |||
| wg.Done() | |||
| return | |||
| } | |||
| i.inference.AppendRoute(imageUrls) | |||
| clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) | |||
| var f FilteredCluster | |||
| f.urls = imageUrls | |||
| f.clusterId = c.ClusterId | |||
| f.clusterName = clusterName | |||
| f.imageNum = c.Replicas | |||
| ch <- &f | |||
| wg.Done() | |||
| return | |||
| }() | |||
| } | |||
| wg.Wait() | |||
| close(ch) | |||
| for s := range ch { | |||
| cs = append(cs, s) | |||
| } | |||
| return cs, nil | |||
| } | |||
| func (i *ImageInference) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { | |||
| var wg sync.WaitGroup | |||
| var ch = make(chan *types.ImageResult, len(i.files)) | |||
| var results []*types.ImageResult | |||
| limit := make(chan bool, 7) | |||
| var imageNumIdx int32 = 0 | |||
| var imageNumIdxEnd int32 = 0 | |||
| for _, c := range cs { | |||
| new_images := make([]*ImageFile, len(i.files)) | |||
| copy(new_images, i.files) | |||
| imageNumIdxEnd = imageNumIdxEnd + c.imageNum | |||
| new_images = new_images[imageNumIdx:imageNumIdxEnd] | |||
| imageNumIdx = imageNumIdx + c.imageNum | |||
| wg.Add(len(new_images)) | |||
| go sendInferReq(new_images, c, &wg, ch, limit) | |||
| } | |||
| wg.Wait() | |||
| close(ch) | |||
| for s := range ch { | |||
| results = append(results, s) | |||
| } | |||
| sort.Slice(results, func(p, q int) bool { | |||
| return results[p].ClusterName < results[q].ClusterName | |||
| }) | |||
| return results, nil | |||
| } | |||
| func (i *ImageInference) updateStatus(aiTaskList []*models.TaskAi, cs []*FilteredCluster) error { | |||
| //no cluster available | |||
| if len(cs) == 0 { | |||
| for _, t := range aiTaskList { | |||
| t.Status = constants.Failed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||
| } | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||
| return errors.New("available clusters' empty, image infer task failed") | |||
| } | |||
| //change cluster status | |||
| if len(i.clusters) != len(cs) { | |||
| var acs []*strategy.AssignedCluster | |||
| var rcs []*strategy.AssignedCluster | |||
| for _, cluster := range i.clusters { | |||
| if contains(cs, cluster.ClusterId) { | |||
| var ac *strategy.AssignedCluster | |||
| ac = cluster | |||
| rcs = append(rcs, ac) | |||
| } else { | |||
| var ac *strategy.AssignedCluster | |||
| ac = cluster | |||
| acs = append(acs, ac) | |||
| } | |||
| } | |||
| // update failed cluster status | |||
| for _, ac := range acs { | |||
| for _, t := range aiTaskList { | |||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| t.Status = constants.Failed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||
| } | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // update running cluster status | |||
| for _, ac := range rcs { | |||
| for _, t := range aiTaskList { | |||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| t.Status = constants.Running | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||
| } else { | |||
| for _, t := range aiTaskList { | |||
| t.Status = constants.Running | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| } | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") | |||
| } | |||
| return nil | |||
| } | |||
| func sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { | |||
| for _, image := range images { | |||
| limit <- true | |||
| go func(t *ImageFile, c *FilteredCluster) { | |||
| if len(c.urls) == 1 { | |||
| r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||
| if err != nil { | |||
| t.ImageResult.ImageResult = err.Error() | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[0].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } | |||
| t.ImageResult.ImageResult = r | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[0].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } else { | |||
| idx := rand.Intn(len(c.urls)) | |||
| r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||
| if err != nil { | |||
| t.ImageResult.ImageResult = err.Error() | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[idx].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } | |||
| t.ImageResult.ImageResult = r | |||
| t.ImageResult.ClusterId = c.clusterId | |||
| t.ImageResult.ClusterName = c.clusterName | |||
| t.ImageResult.Card = c.urls[idx].Card | |||
| ch <- t.ImageResult | |||
| wg.Done() | |||
| <-limit | |||
| return | |||
| } | |||
| }(image, cluster) | |||
| <-limit | |||
| } | |||
| } | |||
| func (i *ImageInference) saveAiSubTasks(id int64, aiTaskList []*models.TaskAi, cs []*FilteredCluster, results []*types.ImageResult) error { | |||
| //save ai sub tasks | |||
| for _, r := range results { | |||
| for _, task := range aiTaskList { | |||
| if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { | |||
| taskAiSub := models.TaskAiSub{ | |||
| TaskId: id, | |||
| TaskName: task.Name, | |||
| TaskAiId: task.TaskId, | |||
| TaskAiName: task.Name, | |||
| ImageName: r.ImageName, | |||
| Result: r.ImageResult, | |||
| Card: r.Card, | |||
| ClusterId: task.ClusterId, | |||
| ClusterName: r.ClusterName, | |||
| } | |||
| err := i.storage.SaveAiTaskImageSubTask(&taskAiSub) | |||
| if err != nil { | |||
// surface storage errors to the caller instead of panicking mid-inference
return err
| } | |||
| } | |||
| } | |||
| } | |||
| // update succeeded cluster status | |||
| var successStatusCount int | |||
| for _, c := range cs { | |||
| for _, t := range aiTaskList { | |||
| if c.clusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| t.Status = constants.Completed | |||
| t.EndTime = time.Now().Format(time.RFC3339) | |||
| err := i.storage.UpdateAiTask(t) | |||
| if err != nil { | |||
| logx.Errorf(err.Error()) | |||
| } | |||
| successStatusCount++ | |||
| } else { | |||
| continue | |||
| } | |||
| } | |||
| } | |||
| if len(cs) == successStatusCount { | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") | |||
| } else { | |||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||
| } | |||
| return nil | |||
| } | |||
| func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { | |||
| if clusterName == "鹏城云脑II-modelarts" { | |||
| r, err := getInferResultModelarts(url, file, fileName) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
| return r, nil | |||
| } | |||
| var res Res | |||
| req := GetRestyRequest(20) | |||
| _, err := req. | |||
| SetFileReader("file", fileName, file). | |||
| SetResult(&res). | |||
| Post(url) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
| return res.Result, nil | |||
| } | |||
| func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { | |||
| var res Res | |||
| /* req := GetRestyRequest(20) | |||
| _, err := req. | |||
| SetFileReader("file", fileName, file). | |||
| SetHeaders(map[string]string{ | |||
| "ak": "UNEHPHO4Z7YSNPKRXFE4", | |||
| "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", | |||
| }). | |||
| SetResult(&res). | |||
| Post(url) | |||
| if err != nil { | |||
| return "", err | |||
| }*/ | |||
| body, err := utils.SendRequest("POST", url, file, fileName) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
if errjson := json.Unmarshal([]byte(body), &res); errjson != nil {
	// return the parse error instead of terminating the whole coordinator process
	return "", errjson
}
| return res.Result, nil | |||
| } | |||
| func GetRestyRequest(timeoutSeconds int64) *resty.Request { | |||
| client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) | |||
| request := client.R() | |||
| return request | |||
| } | |||
| type Res struct { | |||
| Result string `json:"result"` | |||
| } | |||
| func contains(cs []*FilteredCluster, e string) bool { | |||
| for _, c := range cs { | |||
| if c.clusterId == e { | |||
| return true | |||
| } | |||
| } | |||
| return false | |||
| } | |||
| @@ -1,19 +1,22 @@ | |||
| package imageInference | |||
| import ( | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||
| import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| const ( | |||
| IMAGETOTEXT = "image-to-text" | |||
| IMAGETOTEXT_AiTYPE = "13" | |||
| ) | |||
| type ImageToText struct { | |||
| files []*ImageFile | |||
| clusters []*strategy.AssignedCluster | |||
| opt *option.InferOption | |||
| storage *database.AiStorage | |||
| inferAdapter map[string]map[string]inference.Inference | |||
| errMap map[string]string | |||
| taskId int64 | |||
| adapterName string | |||
| } | |||
| func (it *ImageToText) AppendRoute(urls []*inference.InferUrl) error { | |||
for i := range urls {
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + IMAGETOTEXT | |||
| } | |||
| return nil | |||
| } | |||
| func (it *ImageToText) GetAiType() string { | |||
| return IMAGETOTEXT_AiTYPE | |||
| } | |||
| @@ -5,377 +5,25 @@ import ( | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| ) | |||
| type Inference interface { | |||
| const ( | |||
| FORWARD_SLASH = "/" | |||
| ) | |||
| type ICluster interface { | |||
| GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error) | |||
| //GetInferDeployInstanceList(ctx context.Context, option *option.InferOption) | |||
| } | |||
| type IInference interface { | |||
| CreateTask() (int64, error) | |||
| InferTask(id int64) error | |||
| } | |||
| type Inference struct { | |||
| In IInference | |||
| } | |||
| type InferUrl struct { | |||
| Url string | |||
| Card string | |||
| } | |||
| //func ImageInfer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile, inferAdapterMap map[string]map[string]Inference, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) { | |||
| // | |||
| // //for i := len(clusters) - 1; i >= 0; i-- { | |||
| // // if clusters[i].Replicas == 0 { | |||
| // // clusters = append(clusters[:i], clusters[i+1:]...) | |||
| // // } | |||
| // //} | |||
| // var wg sync.WaitGroup | |||
| // var cluster_ch = make(chan struct { | |||
| // urls []*InferUrl | |||
| // clusterId string | |||
| // clusterName string | |||
| // imageNum int32 | |||
| // }, len(clusters)) | |||
| // | |||
| // var cs []struct { | |||
| // urls []*InferUrl | |||
| // clusterId string | |||
| // clusterName string | |||
| // imageNum int32 | |||
| // } | |||
| // inferMap := inferAdapterMap[opt.AdapterId] | |||
| // | |||
| // ////save taskai | |||
| // //for _, c := range clusters { | |||
| // // clusterName, _ := storage.GetClusterNameById(c.ClusterId) | |||
| // // opt.Replica = c.Replicas | |||
| // // err := storage.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||
| // // if err != nil { | |||
| // // return nil, err | |||
| // // } | |||
| // //} | |||
| // | |||
| // var mutex sync.Mutex | |||
| // errMap := make(map[string]string) | |||
| // for _, cluster := range clusters { | |||
| // wg.Add(1) | |||
| // c := cluster | |||
| // go func() { | |||
| // imageUrls, err := inferMap[c.ClusterId].GetInferUrl(ctx, opt) | |||
| // if err != nil { | |||
| // mutex.Lock() | |||
| // errMap[c.ClusterId] = err.Error() | |||
| // mutex.Unlock() | |||
| // wg.Done() | |||
| // return | |||
| // } | |||
| // for i, _ := range imageUrls { | |||
| // imageUrls[i].Url = imageUrls[i].Url + "/" + "image" | |||
| // } | |||
| // clusterName, _ := storage.GetClusterNameById(c.ClusterId) | |||
| // | |||
| // s := struct { | |||
| // urls []*InferUrl | |||
| // clusterId string | |||
| // clusterName string | |||
| // imageNum int32 | |||
| // }{ | |||
| // urls: imageUrls, | |||
| // clusterId: c.ClusterId, | |||
| // clusterName: clusterName, | |||
| // imageNum: c.Replicas, | |||
| // } | |||
| // | |||
| // cluster_ch <- s | |||
| // wg.Done() | |||
| // return | |||
| // }() | |||
| // } | |||
| // wg.Wait() | |||
| // close(cluster_ch) | |||
| // | |||
| // for s := range cluster_ch { | |||
| // cs = append(cs, s) | |||
| // } | |||
| // | |||
| // aiTaskList, err := storage.GetAiTaskListById(id) | |||
| // if err != nil { | |||
| // return nil, err | |||
| // } | |||
| // | |||
| // //no cluster available | |||
| // if len(cs) == 0 { | |||
| // for _, t := range aiTaskList { | |||
| // t.Status = constants.Failed | |||
| // t.EndTime = time.Now().Format(time.RFC3339) | |||
| // if _, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||
| // t.Msg = errMap[strconv.Itoa(int(t.ClusterId))] | |||
| // } | |||
| // err := storage.UpdateAiTask(t) | |||
| // if err != nil { | |||
| // logx.Errorf(err.Error()) | |||
| // } | |||
| // } | |||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||
| // return nil, errors.New("image infer task failed") | |||
| // } | |||
| // | |||
| // //change cluster status | |||
| // if len(clusters) != len(cs) { | |||
| // var acs []*strategy.AssignedCluster | |||
| // var rcs []*strategy.AssignedCluster | |||
| // for _, cluster := range clusters { | |||
| // if contains(cs, cluster.ClusterId) { | |||
| // var ac *strategy.AssignedCluster | |||
| // ac = cluster | |||
| // rcs = append(rcs, ac) | |||
| // } else { | |||
| // var ac *strategy.AssignedCluster | |||
| // ac = cluster | |||
| // acs = append(acs, ac) | |||
| // } | |||
| // } | |||
| // | |||
| // // update failed cluster status | |||
| // for _, ac := range acs { | |||
| // for _, t := range aiTaskList { | |||
| // if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| // t.Status = constants.Failed | |||
| // t.EndTime = time.Now().Format(time.RFC3339) | |||
| // if _, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||
| // t.Msg = errMap[strconv.Itoa(int(t.ClusterId))] | |||
| // } | |||
| // err := storage.UpdateAiTask(t) | |||
| // if err != nil { | |||
| // logx.Errorf(err.Error()) | |||
| // } | |||
| // } | |||
| // } | |||
| // } | |||
| // | |||
| // // update running cluster status | |||
| // for _, ac := range rcs { | |||
| // for _, t := range aiTaskList { | |||
| // if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| // t.Status = constants.Running | |||
| // err := storage.UpdateAiTask(t) | |||
| // if err != nil { | |||
| // logx.Errorf(err.Error()) | |||
| // } | |||
| // } | |||
| // } | |||
| // } | |||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||
| // } else { | |||
| // for _, t := range aiTaskList { | |||
| // t.Status = constants.Running | |||
| // err := storage.UpdateAiTask(t) | |||
| // if err != nil { | |||
| // logx.Errorf(err.Error()) | |||
| // } | |||
| // } | |||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") | |||
| // } | |||
| // | |||
| // var result_ch = make(chan *types.ImageResult, len(ts)) | |||
| // var results []*types.ImageResult | |||
| // limit := make(chan bool, 7) | |||
| // | |||
| // var imageNumIdx int32 = 0 | |||
| // var imageNumIdxEnd int32 = 0 | |||
| // for _, c := range cs { | |||
| // new_images := make([]*ImageFile, len(ts)) | |||
| // copy(new_images, ts) | |||
| // | |||
| // imageNumIdxEnd = imageNumIdxEnd + c.imageNum | |||
| // new_images = new_images[imageNumIdx:imageNumIdxEnd] | |||
| // imageNumIdx = imageNumIdx + c.imageNum | |||
| // | |||
| // wg.Add(len(new_images)) | |||
| // go sendInferReq(new_images, c, &wg, result_ch, limit) | |||
| // } | |||
| // wg.Wait() | |||
| // close(result_ch) | |||
| // | |||
| // for s := range result_ch { | |||
| // results = append(results, s) | |||
| // } | |||
| // | |||
| // sort.Slice(results, func(p, q int) bool { | |||
| // return results[p].ClusterName < results[q].ClusterName | |||
| // }) | |||
| // | |||
| // //save ai sub tasks | |||
| // for _, r := range results { | |||
| // for _, task := range aiTaskList { | |||
| // if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { | |||
| // taskAiSub := models.TaskAiSub{ | |||
| // TaskId: id, | |||
| // TaskName: task.Name, | |||
| // TaskAiId: task.TaskId, | |||
| // TaskAiName: task.Name, | |||
| // ImageName: r.ImageName, | |||
| // Result: r.ImageResult, | |||
| // Card: r.Card, | |||
| // ClusterId: task.ClusterId, | |||
| // ClusterName: r.ClusterName, | |||
| // } | |||
| // err := storage.SaveAiTaskImageSubTask(&taskAiSub) | |||
| // if err != nil { | |||
| // panic(err) | |||
| // } | |||
| // } | |||
| // } | |||
| // } | |||
| // | |||
| // // update succeeded cluster status | |||
| // var successStatusCount int | |||
| // for _, c := range cs { | |||
| // for _, t := range aiTaskList { | |||
| // if c.clusterId == strconv.Itoa(int(t.ClusterId)) { | |||
| // t.Status = constants.Completed | |||
| // t.EndTime = time.Now().Format(time.RFC3339) | |||
| // err := storage.UpdateAiTask(t) | |||
| // if err != nil { | |||
| // logx.Errorf(err.Error()) | |||
| // } | |||
| // successStatusCount++ | |||
| // } else { | |||
| // continue | |||
| // } | |||
| // } | |||
| // } | |||
| // | |||
| // if len(cs) == successStatusCount { | |||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") | |||
| // } else { | |||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||
| // } | |||
| // | |||
| // return results, nil | |||
| //} | |||
| // | |||
| //func sendInferReq(images []*ImageFile, cluster struct { | |||
| // urls []*InferUrl | |||
| // clusterId string | |||
| // clusterName string | |||
| // imageNum int32 | |||
| //}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { | |||
| // for _, image := range images { | |||
| // limit <- true | |||
| // go func(t *ImageFile, c struct { | |||
| // urls []*InferUrl | |||
| // clusterId string | |||
| // clusterName string | |||
| // imageNum int32 | |||
| // }) { | |||
| // if len(c.urls) == 1 { | |||
| // r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||
| // if err != nil { | |||
| // t.ImageResult.ImageResult = err.Error() | |||
| // t.ImageResult.ClusterId = c.clusterId | |||
| // t.ImageResult.ClusterName = c.clusterName | |||
| // t.ImageResult.Card = c.urls[0].Card | |||
| // ch <- t.ImageResult | |||
| // wg.Done() | |||
| // <-limit | |||
| // return | |||
| // } | |||
| // t.ImageResult.ImageResult = r | |||
| // t.ImageResult.ClusterId = c.clusterId | |||
| // t.ImageResult.ClusterName = c.clusterName | |||
| // t.ImageResult.Card = c.urls[0].Card | |||
| // | |||
| // ch <- t.ImageResult | |||
| // wg.Done() | |||
| // <-limit | |||
| // return | |||
| // } else { | |||
| // idx := rand.Intn(len(c.urls)) | |||
| // r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||
| // if err != nil { | |||
| // t.ImageResult.ImageResult = err.Error() | |||
| // t.ImageResult.ClusterId = c.clusterId | |||
| // t.ImageResult.ClusterName = c.clusterName | |||
| // t.ImageResult.Card = c.urls[idx].Card | |||
| // ch <- t.ImageResult | |||
| // wg.Done() | |||
| // <-limit | |||
| // return | |||
| // } | |||
| // t.ImageResult.ImageResult = r | |||
| // t.ImageResult.ClusterId = c.clusterId | |||
| // t.ImageResult.ClusterName = c.clusterName | |||
| // t.ImageResult.Card = c.urls[idx].Card | |||
| // | |||
| // ch <- t.ImageResult | |||
| // wg.Done() | |||
| // <-limit | |||
| // return | |||
| // } | |||
| // }(image, cluster) | |||
| // <-limit | |||
| // } | |||
| //} | |||
| // | |||
| //func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { | |||
| // if clusterName == "鹏城云脑II-modelarts" { | |||
| // r, err := getInferResultModelarts(url, file, fileName) | |||
| // if err != nil { | |||
| // return "", err | |||
| // } | |||
| // return r, nil | |||
| // } | |||
| // var res Res | |||
| // req := GetRestyRequest(20) | |||
| // _, err := req. | |||
| // SetFileReader("file", fileName, file). | |||
| // SetResult(&res). | |||
| // Post(url) | |||
| // if err != nil { | |||
| // return "", err | |||
| // } | |||
| // return res.Result, nil | |||
| //} | |||
| // | |||
| //func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { | |||
| // var res Res | |||
| // /* req := GetRestyRequest(20) | |||
| // _, err := req. | |||
| // SetFileReader("file", fileName, file). | |||
| // SetHeaders(map[string]string{ | |||
| // "ak": "UNEHPHO4Z7YSNPKRXFE4", | |||
| // "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", | |||
| // }). | |||
| // SetResult(&res). | |||
| // Post(url) | |||
| // if err != nil { | |||
| // return "", err | |||
| // }*/ | |||
| // body, err := utils.SendRequest("POST", url, file, fileName) | |||
| // if err != nil { | |||
| // return "", err | |||
| // } | |||
| // errjson := json.Unmarshal([]byte(body), &res) | |||
| // if errjson != nil { | |||
| // log.Fatalf("Error parsing JSON: %s", errjson) | |||
| // } | |||
| // return res.Result, nil | |||
| //} | |||
| // | |||
| //func GetRestyRequest(timeoutSeconds int64) *resty.Request { | |||
| // client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) | |||
| // request := client.R() | |||
| // return request | |||
| //} | |||
| // | |||
| //type Res struct { | |||
| // Result string `json:"result"` | |||
| //} | |||
| // | |||
| //func contains(cs []struct { | |||
| // urls []*InferUrl | |||
| // clusterId string | |||
| // clusterName string | |||
| // imageNum int32 | |||
| //}, e string) bool { | |||
| // for _, c := range cs { | |||
| // if c.clusterId == e { | |||
| // return true | |||
| // } | |||
| // } | |||
| // return false | |||
| //} | |||
| @@ -1 +1,97 @@ | |||
| package textInference | |||
| import ( | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||
| ) | |||
| type ITextInference interface { | |||
| SaveAiTask(id int64, adapterName string) error | |||
| UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error | |||
| AppendRoute(urls []*inference.InferUrl) error | |||
| AiType() string | |||
| } | |||
| type FilteredCluster struct { | |||
| urls []*inference.InferUrl | |||
| clusterId string | |||
| clusterName string | |||
| } | |||
| type TextInference struct { | |||
| inference ITextInference | |||
| opt *option.InferOption | |||
| storage *database.AiStorage | |||
| inferAdapter map[string]map[string]inference.ICluster | |||
| errMap map[string]string | |||
| adapterName string | |||
| } | |||
| func New( | |||
| inference ITextInference, | |||
| opt *option.InferOption, | |||
| storage *database.AiStorage, | |||
| inferAdapter map[string]map[string]inference.ICluster, | |||
| adapterName string) (*TextInference, error) { | |||
| return &TextInference{ | |||
| inference: inference, | |||
| opt: opt, | |||
| storage: storage, | |||
| inferAdapter: inferAdapter, | |||
| adapterName: adapterName, | |||
| errMap: make(map[string]string), | |||
| }, nil | |||
| } | |||
| func (ti *TextInference) CreateTask() (int64, error) { | |||
| id, err := ti.saveTask() | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| err = ti.saveAiTask(id) | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| return id, nil | |||
| } | |||
| func (ti *TextInference) InferTask(id int64) error { | |||
| aiTaskList, err := ti.storage.GetAiTaskListById(id) | |||
| if err != nil || len(aiTaskList) == 0 { | |||
| return err | |||
| } | |||
| err = ti.updateStatus(aiTaskList) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (ti *TextInference) saveTask() (int64, error) { | |||
| var synergystatus int64 | |||
| var strategyCode int64 | |||
| id, err := ti.storage.SaveTask(ti.opt.TaskName, strategyCode, synergystatus, ti.inference.AiType()) | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| return id, nil | |||
| } | |||
| func (ti *TextInference) saveAiTask(id int64) error { | |||
| err := ti.inference.SaveAiTask(id, ti.adapterName) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (ti *TextInference) updateStatus(aiTaskList []*models.TaskAi) error { | |||
| err := ti.inference.UpdateStatus(aiTaskList, ti.adapterName) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| @@ -0,0 +1,48 @@ | |||
| package textInference | |||
| import ( | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||
| ) | |||
| const ( | |||
| TEXTTOIMAGE = "text-to-image" | |||
| TEXTTOIMAGE_AiTYPE = "14" | |||
| ) | |||
| type TextToImage struct { | |||
| clusters []*strategy.AssignedCluster | |||
| storage *database.AiStorage | |||
| opt *option.InferOption | |||
| } | |||
| func (t *TextToImage) SaveAiTask(id int64, adapterName string) error { | |||
| for _, c := range t.clusters { | |||
| clusterName, _ := t.storage.GetClusterNameById(c.ClusterId) | |||
| t.opt.Replica = c.Replicas | |||
| err := t.storage.SaveAiTask(id, t.opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| return nil | |||
| } | |||
| func (t *TextToImage) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error { | |||
| return nil | |||
| } | |||
| func (t *TextToImage) AppendRoute(urls []*inference.InferUrl) error { | |||
for i := range urls {
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + TEXTTOIMAGE | |||
| } | |||
| return nil | |||
| } | |||
| func (t *TextToImage) AiType() string { | |||
| return TEXTTOIMAGE_AiTYPE | |||
| } | |||
| @@ -0,0 +1,131 @@ | |||
| package textInference | |||
| import ( | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||
| "net/http" | |||
| "strconv" | |||
| "sync" | |||
| "time" | |||
| ) | |||
| const ( | |||
| CHAT = "chat" | |||
| TEXTTOTEXT_AITYPE = "12" | |||
| ) | |||
| type TextToText struct { | |||
| opt *option.InferOption | |||
| storage *database.AiStorage | |||
| inferAdapter map[string]map[string]inference.ICluster | |||
| cs []*FilteredCluster | |||
| } | |||
| func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) { | |||
| cs, err := filterClusters(opt, storage, inferAdapter) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return &TextToText{ | |||
| opt: opt, | |||
| storage: storage, | |||
| inferAdapter: inferAdapter, | |||
| cs: cs, | |||
| }, nil | |||
| } | |||
| func (tt *TextToText) AppendRoute(urls []*inference.InferUrl) error { | |||
for i := range urls {
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CHAT | |||
| } | |||
| return nil | |||
| } | |||
| func (tt *TextToText) AiType() string { | |||
| return TEXTTOTEXT_AITYPE | |||
| } | |||
| func (tt *TextToText) SaveAiTask(id int64, adapterName string) error { | |||
| if len(tt.cs) == 0 { | |||
| clusterId := tt.opt.AiClusterIds[0] | |||
| clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0]) | |||
| err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, "") | |||
| if err != nil { | |||
| return err | |||
| } | |||
| tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败") | |||
| } | |||
| for _, c := range tt.cs { | |||
| clusterName, _ := tt.storage.GetClusterNameById(c.clusterId) | |||
| err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| return nil | |||
| } | |||
| func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) { | |||
| var wg sync.WaitGroup | |||
| var ch = make(chan *FilteredCluster, len(opt.AiClusterIds)) | |||
| var cs []*FilteredCluster | |||
| inferMap := inferAdapter[opt.AdapterId] | |||
| for _, clusterId := range opt.AiClusterIds { | |||
| wg.Add(1) | |||
| go func(cId string) { | |||
| r := http.Request{} | |||
| urls, err := inferMap[cId].GetInferUrl(r.Context(), opt) | |||
| if err != nil { | |||
| wg.Done() | |||
| return | |||
| } | |||
for i := range urls {
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CHAT | |||
| } | |||
| clusterName, _ := storage.GetClusterNameById(cId) | |||
| var f FilteredCluster | |||
| f.urls = urls | |||
| f.clusterId = cId | |||
| f.clusterName = clusterName | |||
| ch <- &f | |||
| wg.Done() | |||
| return | |||
| }(clusterId) | |||
| } | |||
| wg.Wait() | |||
| close(ch) | |||
| for s := range ch { | |||
| cs = append(cs, s) | |||
| } | |||
| return cs, nil | |||
| } | |||
func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error {
	// Match tasks to filtered clusters by cluster id; tt.cs can be shorter than aiTaskList
	// (e.g. when no cluster returned an inference URL), so do not index it by task position.
	for _, t := range aiTaskList {
		for _, c := range tt.cs {
			if strconv.Itoa(int(t.ClusterId)) != c.clusterId {
				continue
			}
			t.Status = constants.Completed
			t.EndTime = time.Now().Format(time.RFC3339)
			t.InferUrl = c.urls[0].Url
			if err := tt.storage.UpdateAiTask(t); err != nil {
				logx.Errorf(err.Error())
				return err
			}
		}
	}
	tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成")
	return nil
}