|
- package textInference
-
import (
	"errors"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
)
-
// CHAT is the path segment appended (after inference.FORWARD_SLASH) to each
// cluster inference URL to reach the chat endpoint.
const CHAT = "chat"

// TEXTTOTEXT_AITYPE is the AI type code reported by TextToText.GetAiType.
const TEXTTOTEXT_AITYPE = "12"
-
// TextToText handles text-to-text (chat) inference tasks across the clusters
// requested in an InferOption.
type TextToText struct {
	// opt holds the task options; this type reads AdapterId, AiClusterIds
	// and TaskName from it.
	opt *option.InferOption
	// storage is used to save and update AI task records, record notices and
	// resolve cluster names.
	storage *database.AiStorage
	// inferAdapter maps adapter ID -> cluster ID -> inference client.
	inferAdapter map[string]map[string]inference.ICluster
	// cs holds the clusters whose inference URL lookup succeeded, as
	// produced by filterClusters.
	cs []*FilteredCluster
}
-
- func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) {
- cs, err := filterClusters(opt, storage, inferAdapter)
- if err != nil {
- return nil, err
- }
- return &TextToText{
- opt: opt,
- storage: storage,
- inferAdapter: inferAdapter,
- cs: cs,
- }, nil
- }
-
- func (tt *TextToText) GetAiType() string {
- return TEXTTOTEXT_AITYPE
- }
-
- func (tt *TextToText) SaveAiTask(id int64, adapterName string) error {
-
- if len(tt.cs) == 0 {
- clusterId := tt.opt.AiClusterIds[0]
- clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0])
- err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, "")
- if err != nil {
- return err
- }
- tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败")
- }
-
- for _, c := range tt.cs {
- clusterName, _ := tt.storage.GetClusterNameById(c.clusterId)
- err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
- if err != nil {
- return err
- }
- }
- return nil
- }
-
- func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) {
- var wg sync.WaitGroup
- var ch = make(chan *FilteredCluster, len(opt.AiClusterIds))
- var cs []*FilteredCluster
- inferMap := inferAdapter[opt.AdapterId]
-
- for _, clusterId := range opt.AiClusterIds {
- wg.Add(1)
- go func(cId string) {
- r := http.Request{}
- clusterInferUrl, err := inferMap[cId].GetClusterInferUrl(r.Context(), opt)
- if err != nil {
- wg.Done()
- return
- }
-
- for i, _ := range clusterInferUrl.InferUrls {
- clusterInferUrl.InferUrls[i].Url = clusterInferUrl.InferUrls[i].Url + inference.FORWARD_SLASH + CHAT
- }
-
- clusterName, _ := storage.GetClusterNameById(cId)
-
- var f FilteredCluster
- f.urls = clusterInferUrl.InferUrls
- f.clusterId = cId
- f.clusterName = clusterName
-
- ch <- &f
- wg.Done()
- return
- }(clusterId)
- }
- wg.Wait()
- close(ch)
-
- for s := range ch {
- cs = append(cs, s)
- }
-
- return cs, nil
- }
-
- func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error {
- for i, t := range aiTaskList {
- if strconv.Itoa(int(t.ClusterId)) == tt.cs[i].clusterId {
- t.Status = constants.Completed
- t.EndTime = time.Now().Format(time.RFC3339)
- url := tt.cs[i].urls[0].Url
- t.InferUrl = url
- err := tt.storage.UpdateAiTask(t)
- if err != nil {
- logx.Errorf(err.Error())
- return err
- }
- }
- }
-
- tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成")
- return nil
- }
|