| @@ -5,11 +5,11 @@ import ( | |||||
| "errors" | "errors" | ||||
| "github.com/zeromicro/go-zero/core/logx" | "github.com/zeromicro/go-zero/core/logx" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/imageInference" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/imageInference" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||||
| "net/http" | "net/http" | ||||
| ) | ) | ||||
| @@ -102,44 +102,31 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere | |||||
| } | } | ||||
| } | } | ||||
| //save task | |||||
| var synergystatus int64 | |||||
| if len(clusters) > 1 { | |||||
| synergystatus = 1 | |||||
| } | |||||
| strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) | adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) | ||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") | |||||
| imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") | |||||
| in := inference.Inference{ | |||||
| In: imageInfer, | |||||
| } | |||||
| //save taskai | |||||
| for _, c := range clusters { | |||||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) | |||||
| opt.Replica = c.Replicas | |||||
| err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| id, err := in.In.CreateTask() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | } | ||||
| go func() { | go func() { | ||||
| ic, err := imageInference.NewImageClassification(ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, id, adapterName) | |||||
| err := in.In.InferTask(id) | |||||
| if err != nil { | if err != nil { | ||||
| logx.Errorf(err.Error()) | logx.Errorf(err.Error()) | ||||
| return | return | ||||
| } | } | ||||
| ic.Classify() | |||||
| }() | }() | ||||
| return resp, nil | return resp, nil | ||||
| @@ -6,14 +6,9 @@ import ( | |||||
| "github.com/zeromicro/go-zero/core/logx" | "github.com/zeromicro/go-zero/core/logx" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/textInference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | ||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||||
| "strconv" | |||||
| "sync" | |||||
| "time" | |||||
| ) | ) | ||||
| type TextToTextInferenceLogic struct { | type TextToTextInferenceLogic struct { | ||||
| @@ -46,105 +41,29 @@ func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInfe | |||||
| return nil, errors.New("AdapterId does not exist") | return nil, errors.New("AdapterId does not exist") | ||||
| } | } | ||||
| //save task | |||||
| var synergystatus int64 | |||||
| var strategyCode int64 | |||||
| adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) | adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) | ||||
| inType, err := textInference.NewTextToText(opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12") | |||||
| textInfer, err := textInference.New(inType, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| var wg sync.WaitGroup | |||||
| var cluster_ch = make(chan struct { | |||||
| urls []*inference.InferUrl | |||||
| clusterId string | |||||
| clusterName string | |||||
| }, len(opt.AiClusterIds)) | |||||
| var cs []struct { | |||||
| urls []*inference.InferUrl | |||||
| clusterId string | |||||
| clusterName string | |||||
| } | |||||
| inferMap := l.svcCtx.Scheduler.AiService.InferenceAdapterMap[opt.AdapterId] | |||||
| //save taskai | |||||
| for _, clusterId := range opt.AiClusterIds { | |||||
| wg.Add(1) | |||||
| go func(cId string) { | |||||
| urls, err := inferMap[cId].GetInferUrl(l.ctx, opt) | |||||
| if err != nil { | |||||
| wg.Done() | |||||
| return | |||||
| } | |||||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId) | |||||
| s := struct { | |||||
| urls []*inference.InferUrl | |||||
| clusterId string | |||||
| clusterName string | |||||
| }{ | |||||
| urls: urls, | |||||
| clusterId: cId, | |||||
| clusterName: clusterName, | |||||
| } | |||||
| cluster_ch <- s | |||||
| wg.Done() | |||||
| return | |||||
| }(clusterId) | |||||
| } | |||||
| wg.Wait() | |||||
| close(cluster_ch) | |||||
| for s := range cluster_ch { | |||||
| cs = append(cs, s) | |||||
| } | |||||
| if len(cs) == 0 { | |||||
| clusterId := opt.AiClusterIds[0] | |||||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(opt.AiClusterIds[0]) | |||||
| err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, clusterId, clusterName, "", constants.Failed, "") | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||||
| } | |||||
| for _, c := range cs { | |||||
| clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId) | |||||
| err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| in := inference.Inference{ | |||||
| In: textInfer, | |||||
| } | } | ||||
| var aiTaskList []*models.TaskAi | |||||
| tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) | |||||
| if tx.Error != nil { | |||||
| return nil, tx.Error | |||||
| id, err := in.In.CreateTask() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | } | ||||
| for i, t := range aiTaskList { | |||||
| if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId { | |||||
| t.Status = constants.Completed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat" | |||||
| t.InferUrl = url | |||||
| err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(tx.Error.Error()) | |||||
| } | |||||
| } | |||||
| err = in.In.InferTask(id) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | } | ||||
| l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") | |||||
| return resp, nil | return resp, nil | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ const ( | |||||
| type AiService struct { | type AiService struct { | ||||
| AiExecutorAdapterMap map[string]map[string]executor.AiExecutor | AiExecutorAdapterMap map[string]map[string]executor.AiExecutor | ||||
| AiCollectorAdapterMap map[string]map[string]collector.AiCollector | AiCollectorAdapterMap map[string]map[string]collector.AiCollector | ||||
| InferenceAdapterMap map[string]map[string]inference.Inference | |||||
| InferenceAdapterMap map[string]map[string]inference.ICluster | |||||
| Storage *database.AiStorage | Storage *database.AiStorage | ||||
| mu sync.Mutex | mu sync.Mutex | ||||
| } | } | ||||
| @@ -40,7 +40,7 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService | |||||
| aiService := &AiService{ | aiService := &AiService{ | ||||
| AiExecutorAdapterMap: make(map[string]map[string]executor.AiExecutor), | AiExecutorAdapterMap: make(map[string]map[string]executor.AiExecutor), | ||||
| AiCollectorAdapterMap: make(map[string]map[string]collector.AiCollector), | AiCollectorAdapterMap: make(map[string]map[string]collector.AiCollector), | ||||
| InferenceAdapterMap: make(map[string]map[string]inference.Inference), | |||||
| InferenceAdapterMap: make(map[string]map[string]inference.ICluster), | |||||
| Storage: storages, | Storage: storages, | ||||
| } | } | ||||
| for _, id := range adapterIds { | for _, id := range adapterIds { | ||||
| @@ -60,10 +60,10 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService | |||||
| return aiService, nil | return aiService, nil | ||||
| } | } | ||||
| func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector, map[string]inference.Inference) { | |||||
| func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector, map[string]inference.ICluster) { | |||||
| executorMap := make(map[string]executor.AiExecutor) | executorMap := make(map[string]executor.AiExecutor) | ||||
| collectorMap := make(map[string]collector.AiCollector) | collectorMap := make(map[string]collector.AiCollector) | ||||
| inferenceMap := make(map[string]inference.Inference) | |||||
| inferenceMap := make(map[string]inference.ICluster) | |||||
| for _, c := range clusters { | for _, c := range clusters { | ||||
| switch c.Name { | switch c.Name { | ||||
| case OCTOPUS: | case OCTOPUS: | ||||
| @@ -1,419 +1,26 @@ | |||||
| package imageInference | package imageInference | ||||
| import ( | |||||
| "encoding/json" | |||||
| "errors" | |||||
| "github.com/go-resty/resty/v2" | |||||
| "github.com/zeromicro/go-zero/core/logx" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" | |||||
| "log" | |||||
| "math/rand" | |||||
| "mime/multipart" | |||||
| "net/http" | |||||
| "sort" | |||||
| "strconv" | |||||
| "sync" | |||||
| "time" | |||||
| ) | |||||
| import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| const ( | const ( | ||||
| IMAGE = "image" | |||||
| FORWARD_SLASH = "/" | |||||
| CLASSIFICATION = "image" | |||||
| CLASSIFICATION_AiTYPE = "11" | |||||
| ) | ) | ||||
| type ImageClassificationInterface interface { | |||||
| Classify() ([]*types.ImageResult, error) | |||||
| } | |||||
| type ImageFile struct { | |||||
| ImageResult *types.ImageResult | |||||
| File multipart.File | |||||
| } | |||||
| type FilteredCluster struct { | |||||
| urls []*inference.InferUrl | |||||
| clusterId string | |||||
| clusterName string | |||||
| imageNum int32 | |||||
| } | |||||
| type ImageClassification struct { | type ImageClassification struct { | ||||
| files []*ImageFile | |||||
| clusters []*strategy.AssignedCluster | |||||
| opt *option.InferOption | |||||
| storage *database.AiStorage | |||||
| inferAdapter map[string]map[string]inference.Inference | |||||
| errMap map[string]string | |||||
| taskId int64 | |||||
| adapterName string | |||||
| aiTaskList []*models.TaskAi | |||||
| } | |||||
| func NewImageClassification(files []*ImageFile, | |||||
| clusters []*strategy.AssignedCluster, | |||||
| opt *option.InferOption, | |||||
| storage *database.AiStorage, | |||||
| inferAdapter map[string]map[string]inference.Inference, | |||||
| taskId int64, | |||||
| adapterName string) (*ImageClassification, error) { | |||||
| aiTaskList, err := storage.GetAiTaskListById(taskId) | |||||
| if err != nil || len(aiTaskList) == 0 { | |||||
| return nil, err | |||||
| } | |||||
| return &ImageClassification{ | |||||
| files: files, | |||||
| clusters: clusters, | |||||
| opt: opt, | |||||
| storage: storage, | |||||
| inferAdapter: inferAdapter, | |||||
| taskId: taskId, | |||||
| adapterName: adapterName, | |||||
| errMap: make(map[string]string), | |||||
| aiTaskList: aiTaskList, | |||||
| }, nil | |||||
| } | |||||
| func (i *ImageClassification) Classify() ([]*types.ImageResult, error) { | |||||
| clusters, err := i.filterClusters() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| err = i.updateStatus(clusters) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| results, err := i.inferImages(clusters) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return results, nil | |||||
| } | } | ||||
| func (i *ImageClassification) filterClusters() ([]*FilteredCluster, error) { | |||||
| var wg sync.WaitGroup | |||||
| var ch = make(chan *FilteredCluster, len(i.clusters)) | |||||
| var cs []*FilteredCluster | |||||
| var mutex sync.Mutex | |||||
| inferMap := i.inferAdapter[i.opt.AdapterId] | |||||
| for _, cluster := range i.clusters { | |||||
| wg.Add(1) | |||||
| c := cluster | |||||
| go func() { | |||||
| r := http.Request{} | |||||
| imageUrls, err := inferMap[c.ClusterId].GetInferUrl(r.Context(), i.opt) | |||||
| if err != nil { | |||||
| mutex.Lock() | |||||
| i.errMap[c.ClusterId] = err.Error() | |||||
| mutex.Unlock() | |||||
| wg.Done() | |||||
| return | |||||
| } | |||||
| for i, _ := range imageUrls { | |||||
| imageUrls[i].Url = imageUrls[i].Url + FORWARD_SLASH + IMAGE | |||||
| } | |||||
| clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) | |||||
| var f FilteredCluster | |||||
| f.urls = imageUrls | |||||
| f.clusterId = c.ClusterId | |||||
| f.clusterName = clusterName | |||||
| f.imageNum = c.Replicas | |||||
| ch <- &f | |||||
| wg.Done() | |||||
| return | |||||
| }() | |||||
| } | |||||
| wg.Wait() | |||||
| close(ch) | |||||
| for s := range ch { | |||||
| cs = append(cs, s) | |||||
| } | |||||
| return cs, nil | |||||
| } | |||||
| func (i *ImageClassification) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { | |||||
| var wg sync.WaitGroup | |||||
| var ch = make(chan *types.ImageResult, len(i.files)) | |||||
| var results []*types.ImageResult | |||||
| limit := make(chan bool, 7) | |||||
| var imageNumIdx int32 = 0 | |||||
| var imageNumIdxEnd int32 = 0 | |||||
| for _, c := range cs { | |||||
| new_images := make([]*ImageFile, len(i.files)) | |||||
| copy(new_images, i.files) | |||||
| imageNumIdxEnd = imageNumIdxEnd + c.imageNum | |||||
| new_images = new_images[imageNumIdx:imageNumIdxEnd] | |||||
| imageNumIdx = imageNumIdx + c.imageNum | |||||
| wg.Add(len(new_images)) | |||||
| go sendInferReq(new_images, c, &wg, ch, limit) | |||||
| } | |||||
| wg.Wait() | |||||
| close(ch) | |||||
| for s := range ch { | |||||
| results = append(results, s) | |||||
| } | |||||
| sort.Slice(results, func(p, q int) bool { | |||||
| return results[p].ClusterName < results[q].ClusterName | |||||
| }) | |||||
| //save ai sub tasks | |||||
| for _, r := range results { | |||||
| for _, task := range i.aiTaskList { | |||||
| if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { | |||||
| taskAiSub := models.TaskAiSub{ | |||||
| TaskId: i.taskId, | |||||
| TaskName: task.Name, | |||||
| TaskAiId: task.TaskId, | |||||
| TaskAiName: task.Name, | |||||
| ImageName: r.ImageName, | |||||
| Result: r.ImageResult, | |||||
| Card: r.Card, | |||||
| ClusterId: task.ClusterId, | |||||
| ClusterName: r.ClusterName, | |||||
| } | |||||
| err := i.storage.SaveAiTaskImageSubTask(&taskAiSub) | |||||
| if err != nil { | |||||
| panic(err) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // update succeeded cluster status | |||||
| var successStatusCount int | |||||
| for _, c := range cs { | |||||
| for _, t := range i.aiTaskList { | |||||
| if c.clusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| t.Status = constants.Completed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| successStatusCount++ | |||||
| } else { | |||||
| continue | |||||
| } | |||||
| } | |||||
| } | |||||
| if len(cs) == successStatusCount { | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") | |||||
| } else { | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||||
| } | |||||
| return results, nil | |||||
| func NewImageClassification() *ImageClassification { | |||||
| return &ImageClassification{} | |||||
| } | } | ||||
| func (i *ImageClassification) updateStatus(cs []*FilteredCluster) error { | |||||
| //no cluster available | |||||
| if len(cs) == 0 { | |||||
| for _, t := range i.aiTaskList { | |||||
| t.Status = constants.Failed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||||
| } | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||||
| return errors.New("image infer task failed") | |||||
| } | |||||
| //change cluster status | |||||
| if len(i.clusters) != len(cs) { | |||||
| var acs []*strategy.AssignedCluster | |||||
| var rcs []*strategy.AssignedCluster | |||||
| for _, cluster := range i.clusters { | |||||
| if contains(cs, cluster.ClusterId) { | |||||
| var ac *strategy.AssignedCluster | |||||
| ac = cluster | |||||
| rcs = append(rcs, ac) | |||||
| } else { | |||||
| var ac *strategy.AssignedCluster | |||||
| ac = cluster | |||||
| acs = append(acs, ac) | |||||
| } | |||||
| } | |||||
| // update failed cluster status | |||||
| for _, ac := range acs { | |||||
| for _, t := range i.aiTaskList { | |||||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| t.Status = constants.Failed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||||
| } | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // update running cluster status | |||||
| for _, ac := range rcs { | |||||
| for _, t := range i.aiTaskList { | |||||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| t.Status = constants.Running | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||||
| } else { | |||||
| for _, t := range i.aiTaskList { | |||||
| t.Status = constants.Running | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") | |||||
| func (ic *ImageClassification) AppendRoute(urls []*inference.InferUrl) error { | |||||
| for i, _ := range urls { | |||||
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CLASSIFICATION | |||||
| } | } | ||||
| return nil | return nil | ||||
| } | } | ||||
| func sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { | |||||
| for _, image := range images { | |||||
| limit <- true | |||||
| go func(t *ImageFile, c *FilteredCluster) { | |||||
| if len(c.urls) == 1 { | |||||
| r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||||
| if err != nil { | |||||
| t.ImageResult.ImageResult = err.Error() | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[0].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } | |||||
| t.ImageResult.ImageResult = r | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[0].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } else { | |||||
| idx := rand.Intn(len(c.urls)) | |||||
| r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||||
| if err != nil { | |||||
| t.ImageResult.ImageResult = err.Error() | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[idx].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } | |||||
| t.ImageResult.ImageResult = r | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[idx].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } | |||||
| }(image, cluster) | |||||
| <-limit | |||||
| } | |||||
| } | |||||
| func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { | |||||
| if clusterName == "鹏城云脑II-modelarts" { | |||||
| r, err := getInferResultModelarts(url, file, fileName) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| return r, nil | |||||
| } | |||||
| var res Res | |||||
| req := GetRestyRequest(20) | |||||
| _, err := req. | |||||
| SetFileReader("file", fileName, file). | |||||
| SetResult(&res). | |||||
| Post(url) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| return res.Result, nil | |||||
| } | |||||
| func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { | |||||
| var res Res | |||||
| /* req := GetRestyRequest(20) | |||||
| _, err := req. | |||||
| SetFileReader("file", fileName, file). | |||||
| SetHeaders(map[string]string{ | |||||
| "ak": "UNEHPHO4Z7YSNPKRXFE4", | |||||
| "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", | |||||
| }). | |||||
| SetResult(&res). | |||||
| Post(url) | |||||
| if err != nil { | |||||
| return "", err | |||||
| }*/ | |||||
| body, err := utils.SendRequest("POST", url, file, fileName) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| errjson := json.Unmarshal([]byte(body), &res) | |||||
| if errjson != nil { | |||||
| log.Fatalf("Error parsing JSON: %s", errjson) | |||||
| } | |||||
| return res.Result, nil | |||||
| } | |||||
| func GetRestyRequest(timeoutSeconds int64) *resty.Request { | |||||
| client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) | |||||
| request := client.R() | |||||
| return request | |||||
| } | |||||
| type Res struct { | |||||
| Result string `json:"result"` | |||||
| } | |||||
| func contains(cs []*FilteredCluster, e string) bool { | |||||
| for _, c := range cs { | |||||
| if c.clusterId == e { | |||||
| return true | |||||
| } | |||||
| } | |||||
| return false | |||||
| func (ic *ImageClassification) GetAiType() string { | |||||
| return CLASSIFICATION_AiTYPE | |||||
| } | } | ||||
| @@ -1,9 +1,469 @@ | |||||
| package imageInference | package imageInference | ||||
| import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||||
| import ( | |||||
| "encoding/json" | |||||
| "errors" | |||||
| "github.com/go-resty/resty/v2" | |||||
| "github.com/zeromicro/go-zero/core/logx" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" | |||||
| "log" | |||||
| "math/rand" | |||||
| "mime/multipart" | |||||
| "net/http" | |||||
| "sort" | |||||
| "strconv" | |||||
| "sync" | |||||
| "time" | |||||
| ) | |||||
| type ImageInference interface { | |||||
| filterClusters() ([]*FilteredCluster, error) | |||||
| inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) | |||||
| updateStatus(cs []*FilteredCluster) error | |||||
| type IImageInference interface { | |||||
| AppendRoute(urls []*inference.InferUrl) error | |||||
| GetAiType() string | |||||
| } | |||||
| type ImageFile struct { | |||||
| ImageResult *types.ImageResult | |||||
| File multipart.File | |||||
| } | |||||
| type FilteredCluster struct { | |||||
| urls []*inference.InferUrl | |||||
| clusterId string | |||||
| clusterName string | |||||
| imageNum int32 | |||||
| } | |||||
| type ImageInference struct { | |||||
| inference IImageInference | |||||
| files []*ImageFile | |||||
| clusters []*strategy.AssignedCluster | |||||
| opt *option.InferOption | |||||
| storage *database.AiStorage | |||||
| inferAdapter map[string]map[string]inference.ICluster | |||||
| errMap map[string]string | |||||
| adapterName string | |||||
| } | |||||
| func New( | |||||
| inference IImageInference, | |||||
| files []*ImageFile, | |||||
| clusters []*strategy.AssignedCluster, | |||||
| opt *option.InferOption, | |||||
| storage *database.AiStorage, | |||||
| inferAdapter map[string]map[string]inference.ICluster, | |||||
| adapterName string) (*ImageInference, error) { | |||||
| return &ImageInference{ | |||||
| inference: inference, | |||||
| files: files, | |||||
| clusters: clusters, | |||||
| opt: opt, | |||||
| storage: storage, | |||||
| inferAdapter: inferAdapter, | |||||
| adapterName: adapterName, | |||||
| errMap: make(map[string]string), | |||||
| }, nil | |||||
| } | |||||
| func (i *ImageInference) CreateTask() (int64, error) { | |||||
| id, err := i.saveTask() | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| err = i.saveAiTask(id) | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| return id, nil | |||||
| } | |||||
| func (i *ImageInference) InferTask(id int64) error { | |||||
| clusters, err := i.filterClusters() | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| aiTaskList, err := i.storage.GetAiTaskListById(id) | |||||
| if err != nil || len(aiTaskList) == 0 { | |||||
| return err | |||||
| } | |||||
| err = i.updateStatus(aiTaskList, clusters) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| results, err := i.inferImages(clusters) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| err = i.saveAiSubTasks(id, aiTaskList, clusters, results) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (i *ImageInference) saveTask() (int64, error) { | |||||
| var synergystatus int64 | |||||
| if len(i.clusters) > 1 { | |||||
| synergystatus = 1 | |||||
| } | |||||
| strategyCode, err := i.storage.GetStrategyCode(i.opt.Strategy) | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| id, err := i.storage.SaveTask(i.opt.TaskName, strategyCode, synergystatus, i.inference.GetAiType()) | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "create", "任务创建中") | |||||
| return id, nil | |||||
| } | |||||
| func (i *ImageInference) saveAiTask(id int64) error { | |||||
| for _, c := range i.clusters { | |||||
| clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) | |||||
| i.opt.Replica = c.Replicas | |||||
| err := i.storage.SaveAiTask(id, i.opt, i.adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { | |||||
| var wg sync.WaitGroup | |||||
| var ch = make(chan *FilteredCluster, len(i.clusters)) | |||||
| var cs []*FilteredCluster | |||||
| var mutex sync.Mutex | |||||
| inferMap := i.inferAdapter[i.opt.AdapterId] | |||||
| for _, cluster := range i.clusters { | |||||
| wg.Add(1) | |||||
| c := cluster | |||||
| go func() { | |||||
| r := http.Request{} | |||||
| imageUrls, err := inferMap[c.ClusterId].GetInferUrl(r.Context(), i.opt) | |||||
| if err != nil { | |||||
| mutex.Lock() | |||||
| i.errMap[c.ClusterId] = err.Error() | |||||
| mutex.Unlock() | |||||
| wg.Done() | |||||
| return | |||||
| } | |||||
| i.inference.AppendRoute(imageUrls) | |||||
| clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) | |||||
| var f FilteredCluster | |||||
| f.urls = imageUrls | |||||
| f.clusterId = c.ClusterId | |||||
| f.clusterName = clusterName | |||||
| f.imageNum = c.Replicas | |||||
| ch <- &f | |||||
| wg.Done() | |||||
| return | |||||
| }() | |||||
| } | |||||
| wg.Wait() | |||||
| close(ch) | |||||
| for s := range ch { | |||||
| cs = append(cs, s) | |||||
| } | |||||
| return cs, nil | |||||
| } | |||||
| func (i *ImageInference) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { | |||||
| var wg sync.WaitGroup | |||||
| var ch = make(chan *types.ImageResult, len(i.files)) | |||||
| var results []*types.ImageResult | |||||
| limit := make(chan bool, 7) | |||||
| var imageNumIdx int32 = 0 | |||||
| var imageNumIdxEnd int32 = 0 | |||||
| for _, c := range cs { | |||||
| new_images := make([]*ImageFile, len(i.files)) | |||||
| copy(new_images, i.files) | |||||
| imageNumIdxEnd = imageNumIdxEnd + c.imageNum | |||||
| new_images = new_images[imageNumIdx:imageNumIdxEnd] | |||||
| imageNumIdx = imageNumIdx + c.imageNum | |||||
| wg.Add(len(new_images)) | |||||
| go sendInferReq(new_images, c, &wg, ch, limit) | |||||
| } | |||||
| wg.Wait() | |||||
| close(ch) | |||||
| for s := range ch { | |||||
| results = append(results, s) | |||||
| } | |||||
| sort.Slice(results, func(p, q int) bool { | |||||
| return results[p].ClusterName < results[q].ClusterName | |||||
| }) | |||||
| return results, nil | |||||
| } | |||||
| func (i *ImageInference) updateStatus(aiTaskList []*models.TaskAi, cs []*FilteredCluster) error { | |||||
| //no cluster available | |||||
| if len(cs) == 0 { | |||||
| for _, t := range aiTaskList { | |||||
| t.Status = constants.Failed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||||
| } | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||||
| return errors.New("available clusters' empty, image infer task failed") | |||||
| } | |||||
| //change cluster status | |||||
| if len(i.clusters) != len(cs) { | |||||
| var acs []*strategy.AssignedCluster | |||||
| var rcs []*strategy.AssignedCluster | |||||
| for _, cluster := range i.clusters { | |||||
| if contains(cs, cluster.ClusterId) { | |||||
| var ac *strategy.AssignedCluster | |||||
| ac = cluster | |||||
| rcs = append(rcs, ac) | |||||
| } else { | |||||
| var ac *strategy.AssignedCluster | |||||
| ac = cluster | |||||
| acs = append(acs, ac) | |||||
| } | |||||
| } | |||||
| // update failed cluster status | |||||
| for _, ac := range acs { | |||||
| for _, t := range aiTaskList { | |||||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| t.Status = constants.Failed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||||
| t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] | |||||
| } | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // update running cluster status | |||||
| for _, ac := range rcs { | |||||
| for _, t := range aiTaskList { | |||||
| if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| t.Status = constants.Running | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||||
| } else { | |||||
| for _, t := range aiTaskList { | |||||
| t.Status = constants.Running | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| } | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { | |||||
| for _, image := range images { | |||||
| limit <- true | |||||
| go func(t *ImageFile, c *FilteredCluster) { | |||||
| if len(c.urls) == 1 { | |||||
| r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||||
| if err != nil { | |||||
| t.ImageResult.ImageResult = err.Error() | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[0].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } | |||||
| t.ImageResult.ImageResult = r | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[0].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } else { | |||||
| idx := rand.Intn(len(c.urls)) | |||||
| r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||||
| if err != nil { | |||||
| t.ImageResult.ImageResult = err.Error() | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[idx].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } | |||||
| t.ImageResult.ImageResult = r | |||||
| t.ImageResult.ClusterId = c.clusterId | |||||
| t.ImageResult.ClusterName = c.clusterName | |||||
| t.ImageResult.Card = c.urls[idx].Card | |||||
| ch <- t.ImageResult | |||||
| wg.Done() | |||||
| <-limit | |||||
| return | |||||
| } | |||||
| }(image, cluster) | |||||
| <-limit | |||||
| } | |||||
| } | |||||
| func (i *ImageInference) saveAiSubTasks(id int64, aiTaskList []*models.TaskAi, cs []*FilteredCluster, results []*types.ImageResult) error { | |||||
| //save ai sub tasks | |||||
| for _, r := range results { | |||||
| for _, task := range aiTaskList { | |||||
| if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { | |||||
| taskAiSub := models.TaskAiSub{ | |||||
| TaskId: id, | |||||
| TaskName: task.Name, | |||||
| TaskAiId: task.TaskId, | |||||
| TaskAiName: task.Name, | |||||
| ImageName: r.ImageName, | |||||
| Result: r.ImageResult, | |||||
| Card: r.Card, | |||||
| ClusterId: task.ClusterId, | |||||
| ClusterName: r.ClusterName, | |||||
| } | |||||
| err := i.storage.SaveAiTaskImageSubTask(&taskAiSub) | |||||
| if err != nil { | |||||
| panic(err) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // update succeeded cluster status | |||||
| var successStatusCount int | |||||
| for _, c := range cs { | |||||
| for _, t := range aiTaskList { | |||||
| if c.clusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| t.Status = constants.Completed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| err := i.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| } | |||||
| successStatusCount++ | |||||
| } else { | |||||
| continue | |||||
| } | |||||
| } | |||||
| } | |||||
| if len(cs) == successStatusCount { | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") | |||||
| } else { | |||||
| i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { | |||||
| if clusterName == "鹏城云脑II-modelarts" { | |||||
| r, err := getInferResultModelarts(url, file, fileName) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| return r, nil | |||||
| } | |||||
| var res Res | |||||
| req := GetRestyRequest(20) | |||||
| _, err := req. | |||||
| SetFileReader("file", fileName, file). | |||||
| SetResult(&res). | |||||
| Post(url) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| return res.Result, nil | |||||
| } | |||||
| func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { | |||||
| var res Res | |||||
| /* req := GetRestyRequest(20) | |||||
| _, err := req. | |||||
| SetFileReader("file", fileName, file). | |||||
| SetHeaders(map[string]string{ | |||||
| "ak": "UNEHPHO4Z7YSNPKRXFE4", | |||||
| "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", | |||||
| }). | |||||
| SetResult(&res). | |||||
| Post(url) | |||||
| if err != nil { | |||||
| return "", err | |||||
| }*/ | |||||
| body, err := utils.SendRequest("POST", url, file, fileName) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| errjson := json.Unmarshal([]byte(body), &res) | |||||
| if errjson != nil { | |||||
| log.Fatalf("Error parsing JSON: %s", errjson) | |||||
| } | |||||
| return res.Result, nil | |||||
| } | |||||
| func GetRestyRequest(timeoutSeconds int64) *resty.Request { | |||||
| client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) | |||||
| request := client.R() | |||||
| return request | |||||
| } | |||||
| type Res struct { | |||||
| Result string `json:"result"` | |||||
| } | |||||
| func contains(cs []*FilteredCluster, e string) bool { | |||||
| for _, c := range cs { | |||||
| if c.clusterId == e { | |||||
| return true | |||||
| } | |||||
| } | |||||
| return false | |||||
| } | } | ||||
| @@ -1,19 +1,22 @@ | |||||
| package imageInference | package imageInference | ||||
| import ( | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||||
| import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| const ( | |||||
| IMAGETOTEXT = "image-to-text" | |||||
| IMAGETOTEXT_AiTYPE = "13" | |||||
| ) | ) | ||||
| type ImageToText struct { | type ImageToText struct { | ||||
| files []*ImageFile | |||||
| clusters []*strategy.AssignedCluster | |||||
| opt *option.InferOption | |||||
| storage *database.AiStorage | |||||
| inferAdapter map[string]map[string]inference.Inference | |||||
| errMap map[string]string | |||||
| taskId int64 | |||||
| adapterName string | |||||
| } | |||||
| func (it *ImageToText) AppendRoute(urls []*inference.InferUrl) error { | |||||
| for i, _ := range urls { | |||||
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + IMAGETOTEXT | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (it *ImageToText) GetAiType() string { | |||||
| return IMAGETOTEXT_AiTYPE | |||||
| } | } | ||||
| @@ -5,377 +5,25 @@ import ( | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | ||||
| ) | ) | ||||
| type Inference interface { | |||||
| const ( | |||||
| FORWARD_SLASH = "/" | |||||
| ) | |||||
| type ICluster interface { | |||||
| GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error) | GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error) | ||||
| //GetInferDeployInstanceList(ctx context.Context, option *option.InferOption) | //GetInferDeployInstanceList(ctx context.Context, option *option.InferOption) | ||||
| } | } | ||||
| type IInference interface { | |||||
| CreateTask() (int64, error) | |||||
| InferTask(id int64) error | |||||
| } | |||||
| type Inference struct { | |||||
| In IInference | |||||
| } | |||||
| type InferUrl struct { | type InferUrl struct { | ||||
| Url string | Url string | ||||
| Card string | Card string | ||||
| } | } | ||||
| //func ImageInfer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile, inferAdapterMap map[string]map[string]Inference, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) { | |||||
| // | |||||
| // //for i := len(clusters) - 1; i >= 0; i-- { | |||||
| // // if clusters[i].Replicas == 0 { | |||||
| // // clusters = append(clusters[:i], clusters[i+1:]...) | |||||
| // // } | |||||
| // //} | |||||
| // var wg sync.WaitGroup | |||||
| // var cluster_ch = make(chan struct { | |||||
| // urls []*InferUrl | |||||
| // clusterId string | |||||
| // clusterName string | |||||
| // imageNum int32 | |||||
| // }, len(clusters)) | |||||
| // | |||||
| // var cs []struct { | |||||
| // urls []*InferUrl | |||||
| // clusterId string | |||||
| // clusterName string | |||||
| // imageNum int32 | |||||
| // } | |||||
| // inferMap := inferAdapterMap[opt.AdapterId] | |||||
| // | |||||
| // ////save taskai | |||||
| // //for _, c := range clusters { | |||||
| // // clusterName, _ := storage.GetClusterNameById(c.ClusterId) | |||||
| // // opt.Replica = c.Replicas | |||||
| // // err := storage.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||||
| // // if err != nil { | |||||
| // // return nil, err | |||||
| // // } | |||||
| // //} | |||||
| // | |||||
| // var mutex sync.Mutex | |||||
| // errMap := make(map[string]string) | |||||
| // for _, cluster := range clusters { | |||||
| // wg.Add(1) | |||||
| // c := cluster | |||||
| // go func() { | |||||
| // imageUrls, err := inferMap[c.ClusterId].GetInferUrl(ctx, opt) | |||||
| // if err != nil { | |||||
| // mutex.Lock() | |||||
| // errMap[c.ClusterId] = err.Error() | |||||
| // mutex.Unlock() | |||||
| // wg.Done() | |||||
| // return | |||||
| // } | |||||
| // for i, _ := range imageUrls { | |||||
| // imageUrls[i].Url = imageUrls[i].Url + "/" + "image" | |||||
| // } | |||||
| // clusterName, _ := storage.GetClusterNameById(c.ClusterId) | |||||
| // | |||||
| // s := struct { | |||||
| // urls []*InferUrl | |||||
| // clusterId string | |||||
| // clusterName string | |||||
| // imageNum int32 | |||||
| // }{ | |||||
| // urls: imageUrls, | |||||
| // clusterId: c.ClusterId, | |||||
| // clusterName: clusterName, | |||||
| // imageNum: c.Replicas, | |||||
| // } | |||||
| // | |||||
| // cluster_ch <- s | |||||
| // wg.Done() | |||||
| // return | |||||
| // }() | |||||
| // } | |||||
| // wg.Wait() | |||||
| // close(cluster_ch) | |||||
| // | |||||
| // for s := range cluster_ch { | |||||
| // cs = append(cs, s) | |||||
| // } | |||||
| // | |||||
| // aiTaskList, err := storage.GetAiTaskListById(id) | |||||
| // if err != nil { | |||||
| // return nil, err | |||||
| // } | |||||
| // | |||||
| // //no cluster available | |||||
| // if len(cs) == 0 { | |||||
| // for _, t := range aiTaskList { | |||||
| // t.Status = constants.Failed | |||||
| // t.EndTime = time.Now().Format(time.RFC3339) | |||||
| // if _, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||||
| // t.Msg = errMap[strconv.Itoa(int(t.ClusterId))] | |||||
| // } | |||||
| // err := storage.UpdateAiTask(t) | |||||
| // if err != nil { | |||||
| // logx.Errorf(err.Error()) | |||||
| // } | |||||
| // } | |||||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||||
| // return nil, errors.New("image infer task failed") | |||||
| // } | |||||
| // | |||||
| // //change cluster status | |||||
| // if len(clusters) != len(cs) { | |||||
| // var acs []*strategy.AssignedCluster | |||||
| // var rcs []*strategy.AssignedCluster | |||||
| // for _, cluster := range clusters { | |||||
| // if contains(cs, cluster.ClusterId) { | |||||
| // var ac *strategy.AssignedCluster | |||||
| // ac = cluster | |||||
| // rcs = append(rcs, ac) | |||||
| // } else { | |||||
| // var ac *strategy.AssignedCluster | |||||
| // ac = cluster | |||||
| // acs = append(acs, ac) | |||||
| // } | |||||
| // } | |||||
| // | |||||
| // // update failed cluster status | |||||
| // for _, ac := range acs { | |||||
| // for _, t := range aiTaskList { | |||||
| // if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| // t.Status = constants.Failed | |||||
| // t.EndTime = time.Now().Format(time.RFC3339) | |||||
| // if _, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok { | |||||
| // t.Msg = errMap[strconv.Itoa(int(t.ClusterId))] | |||||
| // } | |||||
| // err := storage.UpdateAiTask(t) | |||||
| // if err != nil { | |||||
| // logx.Errorf(err.Error()) | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // | |||||
| // // update running cluster status | |||||
| // for _, ac := range rcs { | |||||
| // for _, t := range aiTaskList { | |||||
| // if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| // t.Status = constants.Running | |||||
| // err := storage.UpdateAiTask(t) | |||||
| // if err != nil { | |||||
| // logx.Errorf(err.Error()) | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||||
| // } else { | |||||
| // for _, t := range aiTaskList { | |||||
| // t.Status = constants.Running | |||||
| // err := storage.UpdateAiTask(t) | |||||
| // if err != nil { | |||||
| // logx.Errorf(err.Error()) | |||||
| // } | |||||
| // } | |||||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") | |||||
| // } | |||||
| // | |||||
| // var result_ch = make(chan *types.ImageResult, len(ts)) | |||||
| // var results []*types.ImageResult | |||||
| // limit := make(chan bool, 7) | |||||
| // | |||||
| // var imageNumIdx int32 = 0 | |||||
| // var imageNumIdxEnd int32 = 0 | |||||
| // for _, c := range cs { | |||||
| // new_images := make([]*ImageFile, len(ts)) | |||||
| // copy(new_images, ts) | |||||
| // | |||||
| // imageNumIdxEnd = imageNumIdxEnd + c.imageNum | |||||
| // new_images = new_images[imageNumIdx:imageNumIdxEnd] | |||||
| // imageNumIdx = imageNumIdx + c.imageNum | |||||
| // | |||||
| // wg.Add(len(new_images)) | |||||
| // go sendInferReq(new_images, c, &wg, result_ch, limit) | |||||
| // } | |||||
| // wg.Wait() | |||||
| // close(result_ch) | |||||
| // | |||||
| // for s := range result_ch { | |||||
| // results = append(results, s) | |||||
| // } | |||||
| // | |||||
| // sort.Slice(results, func(p, q int) bool { | |||||
| // return results[p].ClusterName < results[q].ClusterName | |||||
| // }) | |||||
| // | |||||
| // //save ai sub tasks | |||||
| // for _, r := range results { | |||||
| // for _, task := range aiTaskList { | |||||
| // if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { | |||||
| // taskAiSub := models.TaskAiSub{ | |||||
| // TaskId: id, | |||||
| // TaskName: task.Name, | |||||
| // TaskAiId: task.TaskId, | |||||
| // TaskAiName: task.Name, | |||||
| // ImageName: r.ImageName, | |||||
| // Result: r.ImageResult, | |||||
| // Card: r.Card, | |||||
| // ClusterId: task.ClusterId, | |||||
| // ClusterName: r.ClusterName, | |||||
| // } | |||||
| // err := storage.SaveAiTaskImageSubTask(&taskAiSub) | |||||
| // if err != nil { | |||||
| // panic(err) | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // | |||||
| // // update succeeded cluster status | |||||
| // var successStatusCount int | |||||
| // for _, c := range cs { | |||||
| // for _, t := range aiTaskList { | |||||
| // if c.clusterId == strconv.Itoa(int(t.ClusterId)) { | |||||
| // t.Status = constants.Completed | |||||
| // t.EndTime = time.Now().Format(time.RFC3339) | |||||
| // err := storage.UpdateAiTask(t) | |||||
| // if err != nil { | |||||
| // logx.Errorf(err.Error()) | |||||
| // } | |||||
| // successStatusCount++ | |||||
| // } else { | |||||
| // continue | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| // | |||||
| // if len(cs) == successStatusCount { | |||||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") | |||||
| // } else { | |||||
| // storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") | |||||
| // } | |||||
| // | |||||
| // return results, nil | |||||
| //} | |||||
| // | |||||
| //func sendInferReq(images []*ImageFile, cluster struct { | |||||
| // urls []*InferUrl | |||||
| // clusterId string | |||||
| // clusterName string | |||||
| // imageNum int32 | |||||
| //}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { | |||||
| // for _, image := range images { | |||||
| // limit <- true | |||||
| // go func(t *ImageFile, c struct { | |||||
| // urls []*InferUrl | |||||
| // clusterId string | |||||
| // clusterName string | |||||
| // imageNum int32 | |||||
| // }) { | |||||
| // if len(c.urls) == 1 { | |||||
| // r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||||
| // if err != nil { | |||||
| // t.ImageResult.ImageResult = err.Error() | |||||
| // t.ImageResult.ClusterId = c.clusterId | |||||
| // t.ImageResult.ClusterName = c.clusterName | |||||
| // t.ImageResult.Card = c.urls[0].Card | |||||
| // ch <- t.ImageResult | |||||
| // wg.Done() | |||||
| // <-limit | |||||
| // return | |||||
| // } | |||||
| // t.ImageResult.ImageResult = r | |||||
| // t.ImageResult.ClusterId = c.clusterId | |||||
| // t.ImageResult.ClusterName = c.clusterName | |||||
| // t.ImageResult.Card = c.urls[0].Card | |||||
| // | |||||
| // ch <- t.ImageResult | |||||
| // wg.Done() | |||||
| // <-limit | |||||
| // return | |||||
| // } else { | |||||
| // idx := rand.Intn(len(c.urls)) | |||||
| // r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) | |||||
| // if err != nil { | |||||
| // t.ImageResult.ImageResult = err.Error() | |||||
| // t.ImageResult.ClusterId = c.clusterId | |||||
| // t.ImageResult.ClusterName = c.clusterName | |||||
| // t.ImageResult.Card = c.urls[idx].Card | |||||
| // ch <- t.ImageResult | |||||
| // wg.Done() | |||||
| // <-limit | |||||
| // return | |||||
| // } | |||||
| // t.ImageResult.ImageResult = r | |||||
| // t.ImageResult.ClusterId = c.clusterId | |||||
| // t.ImageResult.ClusterName = c.clusterName | |||||
| // t.ImageResult.Card = c.urls[idx].Card | |||||
| // | |||||
| // ch <- t.ImageResult | |||||
| // wg.Done() | |||||
| // <-limit | |||||
| // return | |||||
| // } | |||||
| // }(image, cluster) | |||||
| // <-limit | |||||
| // } | |||||
| //} | |||||
| // | |||||
| //func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { | |||||
| // if clusterName == "鹏城云脑II-modelarts" { | |||||
| // r, err := getInferResultModelarts(url, file, fileName) | |||||
| // if err != nil { | |||||
| // return "", err | |||||
| // } | |||||
| // return r, nil | |||||
| // } | |||||
| // var res Res | |||||
| // req := GetRestyRequest(20) | |||||
| // _, err := req. | |||||
| // SetFileReader("file", fileName, file). | |||||
| // SetResult(&res). | |||||
| // Post(url) | |||||
| // if err != nil { | |||||
| // return "", err | |||||
| // } | |||||
| // return res.Result, nil | |||||
| //} | |||||
| // | |||||
| //func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { | |||||
| // var res Res | |||||
| // /* req := GetRestyRequest(20) | |||||
| // _, err := req. | |||||
| // SetFileReader("file", fileName, file). | |||||
| // SetHeaders(map[string]string{ | |||||
| // "ak": "UNEHPHO4Z7YSNPKRXFE4", | |||||
| // "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", | |||||
| // }). | |||||
| // SetResult(&res). | |||||
| // Post(url) | |||||
| // if err != nil { | |||||
| // return "", err | |||||
| // }*/ | |||||
| // body, err := utils.SendRequest("POST", url, file, fileName) | |||||
| // if err != nil { | |||||
| // return "", err | |||||
| // } | |||||
| // errjson := json.Unmarshal([]byte(body), &res) | |||||
| // if errjson != nil { | |||||
| // log.Fatalf("Error parsing JSON: %s", errjson) | |||||
| // } | |||||
| // return res.Result, nil | |||||
| //} | |||||
| // | |||||
| //func GetRestyRequest(timeoutSeconds int64) *resty.Request { | |||||
| // client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) | |||||
| // request := client.R() | |||||
| // return request | |||||
| //} | |||||
| // | |||||
| //type Res struct { | |||||
| // Result string `json:"result"` | |||||
| //} | |||||
| // | |||||
| //func contains(cs []struct { | |||||
| // urls []*InferUrl | |||||
| // clusterId string | |||||
| // clusterName string | |||||
| // imageNum int32 | |||||
| //}, e string) bool { | |||||
| // for _, c := range cs { | |||||
| // if c.clusterId == e { | |||||
| // return true | |||||
| // } | |||||
| // } | |||||
| // return false | |||||
| //} | |||||
| @@ -1 +1,97 @@ | |||||
| package textInference | package textInference | ||||
| import ( | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||||
| ) | |||||
| type ITextInference interface { | |||||
| SaveAiTask(id int64, adapterName string) error | |||||
| UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error | |||||
| AppendRoute(urls []*inference.InferUrl) error | |||||
| AiType() string | |||||
| } | |||||
| type FilteredCluster struct { | |||||
| urls []*inference.InferUrl | |||||
| clusterId string | |||||
| clusterName string | |||||
| } | |||||
| type TextInference struct { | |||||
| inference ITextInference | |||||
| opt *option.InferOption | |||||
| storage *database.AiStorage | |||||
| inferAdapter map[string]map[string]inference.ICluster | |||||
| errMap map[string]string | |||||
| adapterName string | |||||
| } | |||||
| func New( | |||||
| inference ITextInference, | |||||
| opt *option.InferOption, | |||||
| storage *database.AiStorage, | |||||
| inferAdapter map[string]map[string]inference.ICluster, | |||||
| adapterName string) (*TextInference, error) { | |||||
| return &TextInference{ | |||||
| inference: inference, | |||||
| opt: opt, | |||||
| storage: storage, | |||||
| inferAdapter: inferAdapter, | |||||
| adapterName: adapterName, | |||||
| errMap: make(map[string]string), | |||||
| }, nil | |||||
| } | |||||
| func (ti *TextInference) CreateTask() (int64, error) { | |||||
| id, err := ti.saveTask() | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| err = ti.saveAiTask(id) | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| return id, nil | |||||
| } | |||||
| func (ti *TextInference) InferTask(id int64) error { | |||||
| aiTaskList, err := ti.storage.GetAiTaskListById(id) | |||||
| if err != nil || len(aiTaskList) == 0 { | |||||
| return err | |||||
| } | |||||
| err = ti.updateStatus(aiTaskList) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (ti *TextInference) saveTask() (int64, error) { | |||||
| var synergystatus int64 | |||||
| var strategyCode int64 | |||||
| id, err := ti.storage.SaveTask(ti.opt.TaskName, strategyCode, synergystatus, ti.inference.AiType()) | |||||
| if err != nil { | |||||
| return 0, err | |||||
| } | |||||
| return id, nil | |||||
| } | |||||
| func (ti *TextInference) saveAiTask(id int64) error { | |||||
| err := ti.inference.SaveAiTask(id, ti.adapterName) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (ti *TextInference) updateStatus(aiTaskList []*models.TaskAi) error { | |||||
| err := ti.inference.UpdateStatus(aiTaskList, ti.adapterName) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return nil | |||||
| } | |||||
| @@ -0,0 +1,48 @@ | |||||
| package textInference | |||||
| import ( | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||||
| ) | |||||
| const ( | |||||
| TEXTTOIMAGE = "text-to-image" | |||||
| TEXTTOIMAGE_AiTYPE = "14" | |||||
| ) | |||||
| type TextToImage struct { | |||||
| clusters []*strategy.AssignedCluster | |||||
| storage *database.AiStorage | |||||
| opt *option.InferOption | |||||
| } | |||||
| func (t *TextToImage) SaveAiTask(id int64, adapterName string) error { | |||||
| for _, c := range t.clusters { | |||||
| clusterName, _ := t.storage.GetClusterNameById(c.ClusterId) | |||||
| t.opt.Replica = c.Replicas | |||||
| err := t.storage.SaveAiTask(id, t.opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (t *TextToImage) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error { | |||||
| return nil | |||||
| } | |||||
| func (t *TextToImage) AppendRoute(urls []*inference.InferUrl) error { | |||||
| for i, _ := range urls { | |||||
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + TEXTTOIMAGE | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (t *TextToImage) AiType() string { | |||||
| return TEXTTOIMAGE_AiTYPE | |||||
| } | |||||
| @@ -0,0 +1,131 @@ | |||||
| package textInference | |||||
| import ( | |||||
| "github.com/zeromicro/go-zero/core/logx" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" | |||||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" | |||||
| "net/http" | |||||
| "strconv" | |||||
| "sync" | |||||
| "time" | |||||
| ) | |||||
| const ( | |||||
| CHAT = "chat" | |||||
| TEXTTOTEXT_AITYPE = "12" | |||||
| ) | |||||
| type TextToText struct { | |||||
| opt *option.InferOption | |||||
| storage *database.AiStorage | |||||
| inferAdapter map[string]map[string]inference.ICluster | |||||
| cs []*FilteredCluster | |||||
| } | |||||
| func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) { | |||||
| cs, err := filterClusters(opt, storage, inferAdapter) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &TextToText{ | |||||
| opt: opt, | |||||
| storage: storage, | |||||
| inferAdapter: inferAdapter, | |||||
| cs: cs, | |||||
| }, nil | |||||
| } | |||||
| func (tt *TextToText) AppendRoute(urls []*inference.InferUrl) error { | |||||
| for i, _ := range urls { | |||||
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CHAT | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (tt *TextToText) AiType() string { | |||||
| return TEXTTOTEXT_AITYPE | |||||
| } | |||||
| func (tt *TextToText) SaveAiTask(id int64, adapterName string) error { | |||||
| if len(tt.cs) == 0 { | |||||
| clusterId := tt.opt.AiClusterIds[0] | |||||
| clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0]) | |||||
| err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, "") | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败") | |||||
| } | |||||
| for _, c := range tt.cs { | |||||
| clusterName, _ := tt.storage.GetClusterNameById(c.clusterId) | |||||
| err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) { | |||||
| var wg sync.WaitGroup | |||||
| var ch = make(chan *FilteredCluster, len(opt.AiClusterIds)) | |||||
| var cs []*FilteredCluster | |||||
| inferMap := inferAdapter[opt.AdapterId] | |||||
| for _, clusterId := range opt.AiClusterIds { | |||||
| wg.Add(1) | |||||
| go func(cId string) { | |||||
| r := http.Request{} | |||||
| urls, err := inferMap[cId].GetInferUrl(r.Context(), opt) | |||||
| if err != nil { | |||||
| wg.Done() | |||||
| return | |||||
| } | |||||
| for i, _ := range urls { | |||||
| urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CHAT | |||||
| } | |||||
| clusterName, _ := storage.GetClusterNameById(cId) | |||||
| var f FilteredCluster | |||||
| f.urls = urls | |||||
| f.clusterId = cId | |||||
| f.clusterName = clusterName | |||||
| ch <- &f | |||||
| wg.Done() | |||||
| return | |||||
| }(clusterId) | |||||
| } | |||||
| wg.Wait() | |||||
| close(ch) | |||||
| for s := range ch { | |||||
| cs = append(cs, s) | |||||
| } | |||||
| return cs, nil | |||||
| } | |||||
| func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error { | |||||
| for i, t := range aiTaskList { | |||||
| if strconv.Itoa(int(t.ClusterId)) == tt.cs[i].clusterId { | |||||
| t.Status = constants.Completed | |||||
| t.EndTime = time.Now().Format(time.RFC3339) | |||||
| url := tt.cs[i].urls[0].Url | |||||
| t.InferUrl = url | |||||
| err := tt.storage.UpdateAiTask(t) | |||||
| if err != nil { | |||||
| logx.Errorf(err.Error()) | |||||
| return err | |||||
| } | |||||
| } | |||||
| } | |||||
| tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成") | |||||
| return nil | |||||
| } | |||||