package database

import (
	"fmt"
	"strconv"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/tracker"
	"gorm.io/gorm"
)

// AiStorage wraps a gorm database handle and provides the reads and writes
// over the cluster, adapter, task and notice tables used by the methods below.
type AiStorage struct {
	DbEngin *gorm.DB
}

// GetParticipants returns every row of t_cluster whose deleted_at is NULL,
// ordered by create_time descending.
func (s *AiStorage) GetParticipants() (*types.ClusterListResp, error) {
	var resp types.ClusterListResp
	tx := s.DbEngin.Raw("select * from t_cluster where `deleted_at` IS NULL ORDER BY create_time Desc").Scan(&resp.List)
	if tx.Error != nil {
		// Log the error as a value, not as a format string: DB messages may contain '%'.
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &resp, nil
}

// GetClustersByAdapterId returns the non-deleted rows of t_cluster that belong
// to the given adapter id, ordered by create_time descending.
func (s *AiStorage) GetClustersByAdapterId(id string) (*types.ClusterListResp, error) {
	var resp types.ClusterListResp
	tx := s.DbEngin.Raw("select * from t_cluster where `deleted_at` IS NULL and `adapter_id` = ? ORDER BY create_time Desc", id).Scan(&resp.List)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &resp, nil
}

// GetClusterNameById returns the value of the `description` column of the
// t_cluster row with the given id. The result is the empty string when no row
// is scanned.
func (s *AiStorage) GetClusterNameById(id string) (string, error) {
	var name string
	tx := s.DbEngin.Raw("select `description` from t_cluster where `id` = ?", id).Scan(&name)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return "", tx.Error
	}
	return name, nil
}

// GetAdapterNameById returns the `name` column of the t_adapter row with the
// given id. The result is the empty string when no row is scanned.
func (s *AiStorage) GetAdapterNameById(id string) (string, error) {
	var name string
	tx := s.DbEngin.Raw("select `name` from t_adapter where `id` = ?", id).Scan(&name)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return "", tx.Error
	}
	return name, nil
}

// GetAdapterIdsByType returns the ids of all t_adapter rows of the given type,
// ordered by create_time descending. The slice is nil when nothing matches.
func (s *AiStorage) GetAdapterIdsByType(adapterType string) ([]string, error) {
	var list []types.AdapterInfo
	err := s.DbEngin.Model(&types.AdapterInfo{}).Table("t_adapter").
		Where("type = ?", adapterType).
		Order("create_time desc").
		Find(&list).Error
	if err != nil {
		return nil, err
	}

	var ids []string
	for i := range list {
		ids = append(ids, list[i].Id)
	}
	return ids, nil
}

// GetAdaptersByType returns all t_adapter rows of the given type, ordered by
// create_time descending.
func (s *AiStorage) GetAdaptersByType(adapterType string) ([]*types.AdapterInfo, error) {
	var list []*types.AdapterInfo
	err := s.DbEngin.Model(&types.AdapterInfo{}).Table("t_adapter").
		Where("type = ?", adapterType).
		Order("create_time desc").
		Find(&list).Error
	if err != nil {
		return nil, err
	}
	return list, nil
}

// GetAiTasksByAdapterId returns the task_ai rows for the given adapter id,
// ordered by commit_time descending.
func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, error) {
	var resp []*models.TaskAi
	err := s.DbEngin.Model(&models.TaskAi{}).Table("task_ai").
		Where("adapter_id = ?", adapterId).
		Order("commit_time desc").
		Find(&resp).Error
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// GetAiTaskListById returns the task_ai rows whose task_id equals id.
func (s *AiStorage) GetAiTaskListById(id int64) ([]*models.TaskAi, error) {
	var aiTaskList []*models.TaskAi
	tx := s.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
	if tx.Error != nil {
		return nil, tx.Error
	}
	return aiTaskList, nil
}

// SaveTask inserts a new main task record with status constants.Saved and
// returns the id populated on the model after the insert.
func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) {
	// Take a single timestamp so StartTime and CommitTime agree.
	now := time.Now()

	// Build the main task record.
	taskModel := models.Task{
		Status:          constants.Saved,
		Description:     "ai task",
		Name:            name,
		SynergyStatus:   synergyStatus,
		Strategy:        strategyCode,
		AdapterTypeDict: "1",
		TaskTypeDict:    aiType,
		StartTime:       &now,
		CommitTime:      now,
	}

	// Persist the task to the database.
	if tx := s.DbEngin.Create(&taskModel); tx.Error != nil {
		return 0, tx.Error
	}
	return taskModel.Id, nil
}

// UpdateTask stamps task.UpdatedTime with the current time (formatted with
// constants.Layout) and writes the task to the "task" table via Updates.
// Note that the caller's task value is mutated.
func (s *AiStorage) UpdateTask(task *types.TaskModel) error {
	task.UpdatedTime = time.Now().Format(constants.Layout)
	tx := s.DbEngin.Table("task").Model(task).Updates(task)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return tx.Error
	}
	return nil
}

// SaveAiTask inserts a task_ai record for the given main task.
//
// opt must be a non-nil *option.AiOption or *option.InferOption; an
// InferOption is mapped onto the AiOption fields that are stored. Any other
// value yields an error instead of a nil-pointer dereference. The adapter id
// (taken from the option) and clusterId must be base-10 integers.
func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName string, clusterId string, clusterName string, jobId string, status string, msg string) error {
	var aiOpt *option.AiOption
	switch o := opt.(type) {
	case *option.AiOption:
		aiOpt = o
	case *option.InferOption:
		if o != nil {
			aiOpt = &option.AiOption{}
			aiOpt.TaskName = o.TaskName
			aiOpt.Replica = o.Replica
			aiOpt.AdapterId = o.AdapterId
			aiOpt.TaskType = o.ModelType
			aiOpt.StrategyName = o.Strategy
		}
	}
	if aiOpt == nil {
		return fmt.Errorf("saving ai task: unsupported or nil option of type %T", opt)
	}

	// Build the AI task record.
	aId, err := strconv.ParseInt(aiOpt.AdapterId, 10, 64)
	if err != nil {
		return err
	}
	cId, err := strconv.ParseInt(clusterId, 10, 64)
	if err != nil {
		return err
	}

	now := time.Now()
	aiTaskModel := models.TaskAi{
		TaskId:      taskId,
		AdapterId:   aId,
		AdapterName: adapterName,
		ClusterId:   cId,
		ClusterName: clusterName,
		Name:        aiOpt.TaskName,
		Replica:     int64(aiOpt.Replica),
		JobId:       jobId,
		TaskType:    aiOpt.TaskType,
		Strategy:    aiOpt.StrategyName,
		Status:      status,
		Msg:         msg,
		Card:        aiOpt.ComputeCard,
		StartTime:   now.Format(time.RFC3339),
		CommitTime:  now,
	}

	// Persist the task to the database.
	if tx := s.DbEngin.Create(&aiTaskModel); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// SaveAiTaskImageSubTask inserts ta into the task_ai_sub table.
func (s *AiStorage) SaveAiTaskImageSubTask(ta *models.TaskAiSub) error {
	if tx := s.DbEngin.Table("task_ai_sub").Create(ta); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// SaveClusterTaskQueue inserts a cluster task queue record. adapterId and
// clusterId must be base-10 integers.
func (s *AiStorage) SaveClusterTaskQueue(adapterId string, clusterId string, queueNum int64) error {
	aId, err := strconv.ParseInt(adapterId, 10, 64)
	if err != nil {
		return err
	}
	cId, err := strconv.ParseInt(clusterId, 10, 64)
	if err != nil {
		return err
	}

	taskQueue := models.TClusterTaskQueue{
		AdapterId: aId,
		ClusterId: cId,
		QueueNum:  queueNum,
	}
	if tx := s.DbEngin.Create(&taskQueue); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// GetClusterTaskQueues returns the t_cluster_task_queue rows matching the
// given adapter id and cluster id.
func (s *AiStorage) GetClusterTaskQueues(adapterId string, clusterId string) ([]*models.TClusterTaskQueue, error) {
	var taskQueues []*models.TClusterTaskQueue
	tx := s.DbEngin.Raw("select * from t_cluster_task_queue where `adapter_id` = ? and `cluster_id` = ?", adapterId, clusterId).Scan(&taskQueues)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return taskQueues, nil
}

// GetAiTaskIdByClusterIdAndTaskId returns the JobId of the task_ai row matching
// the given cluster id and task id. The result is the empty string when no row
// is scanned.
func (s *AiStorage) GetAiTaskIdByClusterIdAndTaskId(clusterId string, taskId string) (string, error) {
	var aiTask models.TaskAi
	tx := s.DbEngin.Raw("select * from task_ai where `cluster_id` = ? and `task_id` = ?", clusterId, taskId).Scan(&aiTask)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return "", tx.Error
	}
	return aiTask.JobId, nil
}

// GetClusterResourcesById returns the t_cluster_resource row for the given
// cluster id. When no row is scanned the returned struct holds zero values.
func (s *AiStorage) GetClusterResourcesById(clusterId string) (*models.TClusterResource, error) {
	var clusterResource models.TClusterResource
	tx := s.DbEngin.Raw("select * from t_cluster_resource where `cluster_id` = ?", clusterId).Scan(&clusterResource)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &clusterResource, nil
}

// SaveClusterResources inserts a cluster resource record and then reports the
// cluster's CPU, memory and disk figures to the tracker. adapterId and
// clusterId must be base-10 integers.
func (s *AiStorage) SaveClusterResources(adapterId string, clusterId string, clusterName string, clusterType int64, cpuAvail float64,
	cpuTotal float64, memAvail float64, memTotal float64, diskAvail float64, diskTotal float64, gpuAvail float64, gpuTotal float64,
	cardTotal int64, topsTotal float64, cardHours float64, balance float64, taskCompleted int64) error {
	cId, err := strconv.ParseInt(clusterId, 10, 64)
	if err != nil {
		return err
	}
	aId, err := strconv.ParseInt(adapterId, 10, 64)
	if err != nil {
		return err
	}

	clusterResource := models.TClusterResource{
		AdapterId:     aId,
		ClusterId:     cId,
		ClusterName:   clusterName,
		ClusterType:   clusterType,
		CpuAvail:      cpuAvail,
		CpuTotal:      cpuTotal,
		MemAvail:      memAvail,
		MemTotal:      memTotal,
		DiskAvail:     diskAvail,
		DiskTotal:     diskTotal,
		GpuAvail:      gpuAvail,
		GpuTotal:      gpuTotal,
		CardTotal:     cardTotal,
		CardTopsTotal: topsTotal,
		CardHours:     cardHours,
		Balance:       balance,
		TaskCompleted: taskCompleted,
	}
	if tx := s.DbEngin.Create(&clusterResource); tx.Error != nil {
		return tx.Error
	}

	// Report to the tracker (Prometheus) only after the row was stored.
	tracker.SyncClusterLoad(newClusterLoadRecord(&clusterResource))
	return nil
}

// UpdateClusterResources updates the row whose cluster_id matches
// clusterResource.ClusterId via Updates and then reports the cluster's CPU,
// memory and disk figures to the tracker.
func (s *AiStorage) UpdateClusterResources(clusterResource *models.TClusterResource) error {
	tx := s.DbEngin.Where("cluster_id = ?", clusterResource.ClusterId).Updates(clusterResource)
	if tx.Error != nil {
		return tx.Error
	}

	// Report to the tracker (Prometheus) only after the row was updated.
	tracker.SyncClusterLoad(newClusterLoadRecord(clusterResource))
	return nil
}

// UpdateAiTask writes task to the database via Updates.
func (s *AiStorage) UpdateAiTask(task *models.TaskAi) error {
	if tx := s.DbEngin.Updates(task); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// GetStrategyCode looks up the item_value of the dictionary item whose
// item_text equals name under the 'schedule_Strategy' dictionary. It returns 0
// and the query error if the lookup fails; when no row is scanned the result is
// 0 with a nil error.
func (s *AiStorage) GetStrategyCode(name string) (int64, error) {
	var strategy int64
	sqlStr := `select t_dict_item.item_value
				from t_dict
				left join t_dict_item on t_dict.id = t_dict_item.dict_id
				where item_text = ?
				and t_dict.dict_code = 'schedule_Strategy'`
	// Look up the scheduling strategy.
	if err := s.DbEngin.Raw(sqlStr, name).Scan(&strategy).Error; err != nil {
		// Propagate the failure instead of silently reporting success.
		return 0, err
	}
	return strategy, nil
}

// AddNoticeInfo inserts a row into t_notice describing an incident. It is
// best-effort: id conversion failures and insert failures are logged, not
// returned. An id that fails to parse is stored as 0, and an empty clusterId is
// stored as 0 without logging.
func (s *AiStorage) AddNoticeInfo(adapterId string, adapterName string, clusterId string, clusterName string, taskName string, noticeType string, incident string) {
	aId, err := strconv.ParseInt(adapterId, 10, 64)
	if err != nil {
		logx.Errorf("adapterId convert failure, err: %v", err)
	}

	var cId int64
	if clusterId != "" {
		cId, err = strconv.ParseInt(clusterId, 10, 64)
		if err != nil {
			logx.Errorf("clusterId convert failure, err: %v", err)
		}
	}

	noticeInfo := clientCore.NoticeInfo{
		AdapterId:   aId,
		AdapterName: adapterName,
		ClusterId:   cId,
		ClusterName: clusterName,
		NoticeType:  noticeType,
		TaskName:    taskName,
		Incident:    incident,
		CreatedTime: time.Now(),
	}
	result := s.DbEngin.Table("t_notice").Create(&noticeInfo)
	if result.Error != nil {
		logx.Errorf("Task creation failure, err: %v", result.Error)
	}
}

// newClusterLoadRecord builds the tracker load record from a cluster resource
// row. Each *Utilisation field is the avail/total ratio of the same resource.
// NOTE(review): the ratio is avail/total, i.e. the free share rather than the
// used share, despite the field name — confirm against the tracker's consumers.
func newClusterLoadRecord(res *models.TClusterResource) tracker.ClusterLoadRecord {
	return tracker.ClusterLoadRecord{
		AdapterId:         res.AdapterId,
		ClusterName:       res.ClusterName,
		CpuAvail:          res.CpuAvail,
		CpuTotal:          res.CpuTotal,
		CpuUtilisation:    safeRatio(res.CpuAvail, res.CpuTotal),
		MemoryAvail:       res.MemAvail,
		MemoryTotal:       res.MemTotal,
		MemoryUtilisation: safeRatio(res.MemAvail, res.MemTotal),
		DiskAvail:         res.DiskAvail,
		DiskTotal:         res.DiskTotal,
		DiskUtilisation:   safeRatio(res.DiskAvail, res.DiskTotal),
	}
}

// safeRatio returns avail/total, or 0 when total is zero so that no NaN or Inf
// value is handed to the tracker.
func safeRatio(avail, total float64) float64 {
	if total == 0 {
		return 0
	}
	return avail / total
}