package status

import (
	"context"
	"fmt"
	"strconv"
	"sync"

	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/config"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/utils/jcs"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
)

// TaskStatus queries the AI collectors of each cluster for the state of
// scheduled AI tasks and writes the results back to storage.
type TaskStatus struct {
	// taskSyncLock serializes UpdateAiTaskStatus and UpdateTaskStatus.
	taskSyncLock sync.Mutex

	aiStorages *database.AiStorage

	// aiCollectorAdapterMap holds collectors keyed by adapter ID, then by
	// cluster ID, both formatted as base-10 strings.
	aiCollectorAdapterMap map[string]map[string]collector.AiCollector

	config *config.Config
}

// NewTaskStatus returns a TaskStatus backed by the given storage, collector
// map (adapter ID -> cluster ID -> collector) and configuration.
func NewTaskStatus(storage *database.AiStorage, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, config *config.Config) *TaskStatus {
	return &TaskStatus{
		aiStorages:            storage,
		aiCollectorAdapterMap: aiCollectorAdapterMap,
		config:                config,
	}
}

// UpdateAiTaskStatus refreshes the AI sub-tasks of every task in tasklist
// whose AdapterTypeDict is "1" and whose status is neither Succeeded nor
// Failed. The caller's slice is not modified.
//
// The sub-tasks of each selected task are synced by a separate goroutine
// (updateAiTask), so this call returns, and releases taskSyncLock, without
// waiting for the collectors to answer.
func (s *TaskStatus) UpdateAiTaskStatus(tasklist []*types.TaskModel) {
	s.taskSyncLock.Lock()
	defer s.taskSyncLock.Unlock()

	for _, task := range tasklist {
		// NOTE(review): AdapterTypeDict "1" presumably denotes AI adapters — confirm.
		if task == nil || task.AdapterTypeDict != "1" {
			continue
		}
		if task.Status == constants.Succeeded || task.Status == constants.Failed {
			continue
		}

		aiTaskList, err := s.aiStorages.GetAiTaskListById(task.Id)
		if err != nil {
			logx.Errorf("UpdateAiTaskStatus Get AiTask Error %s", err.Error())
			continue
		}
		if len(aiTaskList) == 0 {
			continue
		}
		go s.updateAiTask(aiTaskList)
	}
}

// UpdateTaskStatus copies the status and start/end times of a task's single
// AI sub-task onto the task itself and persists it. It considers only tasks
// whose AdapterTypeDict is "1" and whose status is not Succeeded, Failed or
// Cancelled; tasks with zero or several AI sub-tasks are left unchanged.
// When the sub-task is Completed or Failed, the result is also reported via
// reportStatusMessages.
func (s *TaskStatus) UpdateTaskStatus(tasklist []*types.TaskModel) {
	s.taskSyncLock.Lock()
	defer s.taskSyncLock.Unlock()

	for _, task := range tasklist {
		// NOTE(review): AdapterTypeDict "1" presumably denotes AI adapters — confirm.
		if task == nil || task.AdapterTypeDict != "1" {
			continue
		}
		if task.Status == constants.Succeeded || task.Status == constants.Failed || task.Status == constants.Cancelled {
			continue
		}

		aiTasks, err := s.aiStorages.GetAiTaskListById(task.Id)
		if err != nil {
			logx.Errorf("UpdateTaskStatus Get AiTask Error %s", err.Error())
			continue
		}
		if len(aiTasks) == 0 {
			continue
		}

		logx.Errorf("############ Report Status Message Before switch %s", task.Status)
		if len(aiTasks) != 1 {
			continue
		}
		s.applyAiTaskStatus(task, aiTasks[0])
	}
}

// applyAiTaskStatus maps aiTask's status onto task (Completed becomes
// Succeeded), reports Completed/Failed outcomes, copies the start/end times
// and persists task. Errors are logged, not returned.
func (s *TaskStatus) applyAiTaskStatus(task *types.TaskModel, aiTask *models.TaskAi) {
	logx.Errorf("############ Report Status Message Switch %s", aiTask.Status)

	finished := false
	switch aiTask.Status {
	case constants.Completed:
		task.Status = constants.Succeeded
		finished = true
	case constants.Failed:
		task.Status = constants.Failed
		finished = true
	default:
		task.Status = aiTask.Status
	}

	// The report is sent after task.Status is updated because
	// reportStatusMessages passes the task on to jcs.TempSaveReportToTask.
	if finished {
		logx.Errorf("############ Report Status Message Before Sending %s", task.Status)
		if err := s.reportStatusMessages(task, aiTask); err != nil {
			logx.Errorf("reportStatusMessages Error %s", err.Error())
		}
	}

	task.StartTime = aiTask.StartTime
	task.EndTime = aiTask.EndTime
	if err := s.aiStorages.UpdateTask(task); err != nil {
		logx.Errorf("UpdateTaskStatus Update Task Error %s", err.Error())
	}
}

// updateAiTask syncs every AI sub-task in aiTaskList that has a JobId and is
// not Completed, Failed or Cancelled. Each one is handled in its own
// goroutine; the call blocks until all of them have finished.
func (s *TaskStatus) updateAiTask(aiTaskList []*models.TaskAi) {
	var wg sync.WaitGroup
	for _, aiTask := range aiTaskList {
		if aiTask == nil || aiTask.JobId == "" {
			continue
		}
		if aiTask.Status == constants.Completed || aiTask.Status == constants.Failed || aiTask.Status == constants.Cancelled {
			continue
		}

		wg.Add(1)
		// The task is passed as an argument so each goroutine gets its own
		// copy regardless of the Go version's loop-variable semantics.
		go func(t *models.TaskAi) {
			defer wg.Done()
			s.syncAiTask(t)
		}(aiTask)
	}
	wg.Wait()
}

// syncAiTask fetches the training job behind t from the collector registered
// for t's adapter and cluster. If the job's status changed, it records a
// notice and updates t.Status. It then copies the job's start/end times and
// persists t. All failures are logged and abort the sync of this task only.
func (s *TaskStatus) syncAiTask(t *models.TaskAi) {
	adapterID := strconv.FormatInt(t.AdapterId, 10)
	clusterID := strconv.FormatInt(t.ClusterId, 10)

	// Indexing a missing adapter yields a nil inner map, and reading from a
	// nil map is safe; a missing entry must not reach a method call, which
	// would panic on the nil interface and crash the process.
	aiCollector, ok := s.aiCollectorAdapterMap[adapterID][clusterID]
	if !ok || aiCollector == nil {
		logx.Errorf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: no collector registered for adapter %v",
			t.Id, t.ClusterId, t.JobId, adapterID)
		return
	}

	trainingTask, err := aiCollector.GetTrainingTask(context.Background(), t.JobId)
	if err != nil {
		logx.Errorf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v",
			t.Id, t.ClusterId, t.JobId, err)
		return
	}
	if trainingTask == nil {
		return
	}

	if t.Status != trainingTask.Status {
		noticeType, noticeMsg := aiTaskStatusNotice(trainingTask.Status)
		s.aiStorages.AddNoticeInfo(adapterID, t.AdapterName, clusterID, t.ClusterName, t.Name, noticeType, noticeMsg)
		t.Status = trainingTask.Status
	}
	t.StartTime = trainingTask.Start
	t.EndTime = trainingTask.End

	if err := s.aiStorages.UpdateAiTask(t); err != nil {
		logx.Errorf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v",
			t.Id, t.ClusterId, t.JobId, err)
	}
}

// aiTaskStatusNotice returns the notice type and message recorded when an AI
// task transitions to the given status. Any status other than Running,
// Failed or Completed is reported as "pending".
func aiTaskStatusNotice(st string) (noticeType, message string) {
	switch st {
	case constants.Running:
		return "running", "任务运行中"
	case constants.Failed:
		return "failed", "任务失败"
	case constants.Completed:
		return "completed", "任务完成"
	default:
		return "pending", "任务pending"
	}
}

// reportStatusMessages sends a "Train" report for task to the configured JCS
// job-status URL, then stores the report on the task via
// jcs.TempSaveReportToTask. The report's Output is the JobId for the "openI"
// cluster, aiTask.Output for "鹏城云脑II-modelarts", and empty otherwise.
func (s *TaskStatus) reportStatusMessages(task *types.TaskModel, aiTask *models.TaskAi) error {
	var output string
	switch aiTask.ClusterName {
	case "openI":
		output = aiTask.JobId
	case "鹏城云脑II-modelarts":
		output = aiTask.Output
	}

	report := &jcs.TrainReportMessage{
		Type:     "Train",
		TaskName: task.Name,
		TaskID:   strconv.FormatInt(task.Id, 10),
	}
	report.Status = true
	report.Message = ""
	report.ClusterID = strconv.FormatInt(aiTask.ClusterId, 10)
	report.Output = output

	if err := jcs.StatusReport(s.config.JcsMiddleware.JobStatusReportUrl, report); err != nil {
		return fmt.Errorf("reporting train status of task %d: %w", task.Id, err)
	}
	if err := jcs.TempSaveReportToTask(s.aiStorages, task, report); err != nil {
		return fmt.Errorf("saving train report to task %d: %w", task.Id, err)
	}
	return nil
}

// ReportStatus sends an "Inference" status report for the given task and
// cluster to the JCS job-status URL configured on the scheduler's AI service.
// status and msg are forwarded as the report's Status and Message; url is the
// inference endpoint being reported.
func ReportStatus(svc *svc.ServiceContext, taskName string, taskId string, clusterId string, url string, status bool, msg string) error {
	report := &jcs.InferReportMessage{
		Type:      "Inference",
		TaskName:  taskName,
		TaskID:    taskId,
		Status:    status,
		Message:   msg,
		ClusterID: clusterId,
		Url:       url,
	}
	return jcs.StatusReport(svc.Scheduler.AiService.Conf.JcsMiddleware.JobStatusReportUrl, report)
}