package textInference import ( "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "net/http" "strconv" "sync" "time" ) const ( CHAT = "chat" TEXTTOTEXT_AITYPE = "12" ) type TextToText struct { opt *option.InferOption storage *database.AiStorage inferAdapter map[string]map[string]inference.ICluster cs []*FilteredCluster } func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) { cs, err := filterClusters(opt, storage, inferAdapter) if err != nil { return nil, err } return &TextToText{ opt: opt, storage: storage, inferAdapter: inferAdapter, cs: cs, }, nil } func (tt *TextToText) GetAiType() string { return TEXTTOTEXT_AITYPE } func (tt *TextToText) SaveAiTask(id int64, adapterName string) error { if len(tt.cs) == 0 { clusterId := tt.opt.AiClusterIds[0] clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0]) err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, "") if err != nil { return err } tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败") } for _, c := range tt.cs { clusterName, _ := tt.storage.GetClusterNameById(c.clusterId) err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") if err != nil { return err } } return nil } func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) { var wg sync.WaitGroup var ch = make(chan *FilteredCluster, len(opt.AiClusterIds)) var cs []*FilteredCluster inferMap := inferAdapter[opt.AdapterId] for _, clusterId := range opt.AiClusterIds { wg.Add(1) go func(cId string) { r := http.Request{} clusterInferUrl, err := inferMap[cId].GetClusterInferUrl(r.Context(), opt) if err != nil { wg.Done() return } for i, _ := range clusterInferUrl.InferUrls { clusterInferUrl.InferUrls[i].Url = clusterInferUrl.InferUrls[i].Url + inference.FORWARD_SLASH + CHAT } clusterName, _ := storage.GetClusterNameById(cId) var f FilteredCluster f.urls = clusterInferUrl.InferUrls f.clusterId = cId f.clusterName = clusterName ch <- &f wg.Done() return }(clusterId) } wg.Wait() close(ch) for s := range ch { cs = append(cs, s) } return cs, nil } func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error { for i, t := range aiTaskList { if strconv.Itoa(int(t.ClusterId)) == tt.cs[i].clusterId { t.Status = constants.Completed t.EndTime = time.Now().Format(time.RFC3339) url := tt.cs[i].urls[0].Url t.InferUrl = url err := tt.storage.UpdateAiTask(t) if err != nil { logx.Errorf(err.Error()) return err } } } tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成") return nil }