package database

import (
	"fmt"
	"strconv"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/tracker"
	"gorm.io/gorm"
)

// AiStorage groups the database operations used for AI tasks, clusters,
// adapters and inference deployments. All methods go through DbEngin.
type AiStorage struct {
	DbEngin *gorm.DB
}

// GetParticipants returns every non-deleted row of t_cluster, newest first
// (ordered by create_time descending).
func (s *AiStorage) GetParticipants() (*types.ClusterListResp, error) {
	var resp types.ClusterListResp
	tx := s.DbEngin.Raw("select * from t_cluster where `deleted_at` IS NULL ORDER BY create_time Desc").Scan(&resp.List)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &resp, nil
}

// GetClustersByAdapterId returns the non-deleted clusters that belong to the
// adapter with the given id, newest first.
func (s *AiStorage) GetClustersByAdapterId(id string) (*types.ClusterListResp, error) {
	var resp types.ClusterListResp
	tx := s.DbEngin.Raw("select * from t_cluster where `deleted_at` IS NULL and `adapter_id` = ? ORDER BY create_time Desc", id).Scan(&resp.List)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &resp, nil
}

// GetClusterNameById returns the `description` column of the t_cluster row
// with the given id; that column is what this storage layer treats as the
// cluster's name.
func (s *AiStorage) GetClusterNameById(id string) (string, error) {
	var name string
	tx := s.DbEngin.Raw("select `description` from t_cluster where `id` = ?", id).Scan(&name)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return "", tx.Error
	}
	return name, nil
}

// GetAdapterNameById returns the `name` column of the t_adapter row with the
// given id.
func (s *AiStorage) GetAdapterNameById(id string) (string, error) {
	var name string
	tx := s.DbEngin.Raw("select `name` from t_adapter where `id` = ?", id).Scan(&name)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return "", tx.Error
	}
	return name, nil
}

// GetAdapterIdsByType returns the ids of all adapters of the given type,
// newest first. The result is a nil slice when no adapter matches.
func (s *AiStorage) GetAdapterIdsByType(adapterType string) ([]string, error) {
	var list []types.AdapterInfo
	var ids []string
	db := s.DbEngin.Model(&types.AdapterInfo{}).Table("t_adapter")
	db = db.Where("type = ?", adapterType)
	if err := db.Order("create_time desc").Find(&list).Error; err != nil {
		return nil, err
	}
	for _, info := range list {
		ids = append(ids, info.Id)
	}
	return ids, nil
}

// GetAdaptersByType returns all adapters of the given type, newest first.
func (s *AiStorage) GetAdaptersByType(adapterType string) ([]*types.AdapterInfo, error) {
	var list []*types.AdapterInfo
	db := s.DbEngin.Model(&types.AdapterInfo{}).Table("t_adapter")
	db = db.Where("type = ?", adapterType)
	if err := db.Order("create_time desc").Find(&list).Error; err != nil {
		return nil, err
	}
	return list, nil
}

// GetAiTasksByAdapterId returns the task_ai rows of the given adapter,
// ordered by commit_time descending.
func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, error) {
	var resp []*models.TaskAi
	db := s.DbEngin.Model(&models.TaskAi{}).Table("task_ai")
	db = db.Where("adapter_id = ?", adapterId)
	if err := db.Order("commit_time desc").Find(&resp).Error; err != nil {
		return nil, err
	}
	return resp, nil
}

// GetAiTaskListById returns the task_ai rows whose task_id equals id.
func (s *AiStorage) GetAiTaskListById(id int64) ([]*models.TaskAi, error) {
	var aiTaskList []*models.TaskAi
	tx := s.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
	if tx.Error != nil {
		return nil, tx.Error
	}
	return aiTaskList, nil
}

// DoesTaskNameExist reports whether a task with the given name already exists.
// "training" looks in the task table and "inference" in
// ai_deploy_instance_task; any other taskType runs no query and reports false.
func (s *AiStorage) DoesTaskNameExist(name string, taskType string) (bool, error) {
	var total int32
	switch taskType {
	case "training":
		tx := s.DbEngin.Raw("select count(*) from task where `name` = ?", name).Scan(&total)
		if tx.Error != nil {
			logx.Error(tx.Error)
			return false, tx.Error
		}
	case "inference":
		tx := s.DbEngin.Raw("select count(*) from ai_deploy_instance_task where `name` = ?", name).Scan(&total)
		if tx.Error != nil {
			logx.Error(tx.Error)
			return false, tx.Error
		}
	}
	return total > 0, nil
}

// SaveTask inserts a new main task with status constants.Saved and returns its
// generated id. When saveToChain is non-nil it is invoked with the stored task
// and its id; a failure there is only logged and does not fail the save.
func (s *AiStorage) SaveTask(name string, desc string, userId int64, strategyCode int64, synergyStatus int64, aiType string, yaml string, saveToChain func(task models.Task, id int64) error) (int64, error) {
	// One timestamp serves as both start and commit time.
	now := time.Now()

	// Build the main task record.
	taskModel := models.Task{
		Status:          constants.Saved,
		Description:     desc,
		Name:            name,
		UserId:          userId,
		SynergyStatus:   synergyStatus,
		Strategy:        strategyCode,
		AdapterTypeDict: "1",
		TaskTypeDict:    aiType,
		YamlString:      yaml,
		StartTime:       &now,
		CommitTime:      now,
	}

	// Persist the task.
	tx := s.DbEngin.Create(&taskModel)
	if tx.Error != nil {
		return 0, tx.Error
	}
	id := taskModel.Id

	// Record the task on chain (best effort).
	if saveToChain != nil {
		if err := saveToChain(taskModel, id); err != nil {
			logx.Error(err)
		}
	}
	return id, nil
}

// UpdateTask stamps task.UpdatedTime with the current time (formatted with
// constants.Layout) and applies the task's fields to the task table.
func (s *AiStorage) UpdateTask(task *types.TaskModel) error {
	task.UpdatedTime = time.Now().Format(constants.Layout)
	tx := s.DbEngin.Table("task").Model(task).Updates(task)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return tx.Error
	}
	return nil
}

// SaveAiTask inserts a task_ai record for the given main task. opt must be a
// non-nil *option.AiOption or *option.InferOption (the latter is mapped onto
// an AiOption); any other value yields an error. An error is also returned
// when the adapter id or clusterId is not a base-10 integer.
func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName string, clusterId string, clusterName string, jobId string, status string, msg string) error {
	var aiOpt *option.AiOption
	switch o := opt.(type) {
	case *option.AiOption:
		aiOpt = o
	case *option.InferOption:
		if o != nil {
			aiOpt = &option.AiOption{}
			aiOpt.TaskName = o.TaskName
			aiOpt.Replica = o.Replica
			aiOpt.AdapterId = o.AdapterId
			aiOpt.TaskType = o.ModelType
			aiOpt.ModelName = o.ModelName
			aiOpt.StrategyName = o.Strategy
		}
	}
	// Without this guard an unsupported or nil option would be dereferenced below.
	if aiOpt == nil {
		return fmt.Errorf("save ai task: unsupported or nil option of type %T", opt)
	}

	// Build the AI task record.
	aId, err := strconv.ParseInt(aiOpt.AdapterId, 10, 64)
	if err != nil {
		return err
	}
	cId, err := strconv.ParseInt(clusterId, 10, 64)
	if err != nil {
		return err
	}
	aiTaskModel := models.TaskAi{
		TaskId:      taskId,
		AdapterId:   aId,
		AdapterName: adapterName,
		ClusterId:   cId,
		ClusterName: clusterName,
		Name:        aiOpt.TaskName,
		Replica:     int64(aiOpt.Replica),
		JobId:       jobId,
		TaskType:    aiOpt.TaskType,
		ModelName:   aiOpt.ModelName,
		Strategy:    aiOpt.StrategyName,
		Status:      status,
		Msg:         msg,
		Output:      aiOpt.Output,
		Card:        aiOpt.ComputeCard,
		StartTime:   time.Now().Format(time.RFC3339),
		CommitTime:  time.Now(),
	}

	// Persist the AI task.
	if tx := s.DbEngin.Create(&aiTaskModel); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// SaveAiTaskImageSubTask inserts ta into the task_ai_sub table.
func (s *AiStorage) SaveAiTaskImageSubTask(ta *models.TaskAiSub) error {
	if tx := s.DbEngin.Table("task_ai_sub").Create(ta); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// SaveClusterTaskQueue inserts a task-queue record for the given adapter and
// cluster. Both ids must be base-10 integers.
func (s *AiStorage) SaveClusterTaskQueue(adapterId string, clusterId string, queueNum int64) error {
	aId, err := strconv.ParseInt(adapterId, 10, 64)
	if err != nil {
		return err
	}
	cId, err := strconv.ParseInt(clusterId, 10, 64)
	if err != nil {
		return err
	}
	taskQueue := models.TClusterTaskQueue{
		AdapterId: aId,
		ClusterId: cId,
		QueueNum:  queueNum,
	}
	if tx := s.DbEngin.Create(&taskQueue); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// GetClusterTaskQueues returns the t_cluster_task_queue rows matching the
// given adapter and cluster ids.
func (s *AiStorage) GetClusterTaskQueues(adapterId string, clusterId string) ([]*models.TClusterTaskQueue, error) {
	var taskQueues []*models.TClusterTaskQueue
	tx := s.DbEngin.Raw("select * from t_cluster_task_queue where `adapter_id` = ? and `cluster_id` = ?", adapterId, clusterId).Scan(&taskQueues)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return taskQueues, nil
}

// GetAiTaskIdByClusterIdAndTaskId returns the JobId of the task_ai row scanned
// for the given cluster and task ids. The query has no LIMIT; a single record
// is scanned.
func (s *AiStorage) GetAiTaskIdByClusterIdAndTaskId(clusterId string, taskId string) (string, error) {
	var aiTask models.TaskAi
	tx := s.DbEngin.Raw("select * from task_ai where `cluster_id` = ? and `task_id` = ?", clusterId, taskId).Scan(&aiTask)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return "", tx.Error
	}
	return aiTask.JobId, nil
}

// GetClusterResourcesById returns the t_cluster_resource record scanned for
// the given cluster id. The returned pointer is never nil on success.
func (s *AiStorage) GetClusterResourcesById(clusterId string) (*models.TClusterResource, error) {
	var clusterResource models.TClusterResource
	tx := s.DbEngin.Raw("select * from t_cluster_resource where `cluster_id` = ?", clusterId).Scan(&clusterResource)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &clusterResource, nil
}

// SaveClusterResources inserts a cluster resource snapshot and then pushes the
// CPU, memory and disk figures to the tracker via tracker.SyncClusterLoad.
// adapterId and clusterId must be base-10 integers.
func (s *AiStorage) SaveClusterResources(adapterId string, clusterId string, clusterName string, clusterType int64, cpuAvail float64, cpuTotal float64, memAvail float64, memTotal float64, diskAvail float64, diskTotal float64, gpuAvail float64, gpuTotal float64, cardTotal int64, topsTotal float64, cardHours float64, balance float64, taskCompleted int64) error {
	cId, err := strconv.ParseInt(clusterId, 10, 64)
	if err != nil {
		return err
	}
	aId, err := strconv.ParseInt(adapterId, 10, 64)
	if err != nil {
		return err
	}
	clusterResource := models.TClusterResource{
		AdapterId:     aId,
		ClusterId:     cId,
		ClusterName:   clusterName,
		ClusterType:   clusterType,
		CpuAvail:      cpuAvail,
		CpuTotal:      cpuTotal,
		MemAvail:      memAvail,
		MemTotal:      memTotal,
		DiskAvail:     diskAvail,
		DiskTotal:     diskTotal,
		GpuAvail:      gpuAvail,
		GpuTotal:      gpuTotal,
		CardTotal:     cardTotal,
		CardTopsTotal: topsTotal,
		CardHours:     cardHours,
		Balance:       balance,
		TaskCompleted: taskCompleted,
	}
	if tx := s.DbEngin.Create(&clusterResource); tx.Error != nil {
		return tx.Error
	}

	// Report the load figures to the tracker.
	// NOTE(review): the utilisation ratios are avail/total and become Inf or
	// NaN when a total is zero — confirm that is acceptable downstream.
	param := tracker.ClusterLoadRecord{
		AdapterId:         aId,
		ClusterName:       clusterName,
		CpuAvail:          cpuAvail,
		CpuTotal:          cpuTotal,
		CpuUtilisation:    clusterResource.CpuAvail / clusterResource.CpuTotal,
		MemoryAvail:       memAvail,
		MemoryTotal:       memTotal,
		MemoryUtilisation: clusterResource.MemAvail / clusterResource.MemTotal,
		DiskAvail:         diskAvail,
		DiskTotal:         diskTotal,
		DiskUtilisation:   clusterResource.DiskAvail / clusterResource.DiskTotal,
	}
	tracker.SyncClusterLoad(param)
	return nil
}

// UpdateClusterResources updates the resource record matching
// clusterResource.ClusterId and then pushes the CPU, memory and disk figures
// to the tracker via tracker.SyncClusterLoad.
func (s *AiStorage) UpdateClusterResources(clusterResource *models.TClusterResource) error {
	tx := s.DbEngin.Where("cluster_id = ?", clusterResource.ClusterId).Updates(clusterResource)
	if tx.Error != nil {
		return tx.Error
	}

	// Report the load figures to the tracker.
	// NOTE(review): the utilisation ratios are avail/total and become Inf or
	// NaN when a total is zero — confirm that is acceptable downstream.
	param := tracker.ClusterLoadRecord{
		AdapterId:         clusterResource.AdapterId,
		ClusterName:       clusterResource.ClusterName,
		CpuAvail:          clusterResource.CpuAvail,
		CpuTotal:          clusterResource.CpuTotal,
		CpuUtilisation:    clusterResource.CpuAvail / clusterResource.CpuTotal,
		MemoryAvail:       clusterResource.MemAvail,
		MemoryTotal:       clusterResource.MemTotal,
		MemoryUtilisation: clusterResource.MemAvail / clusterResource.MemTotal,
		DiskAvail:         clusterResource.DiskAvail,
		DiskTotal:         clusterResource.DiskTotal,
		DiskUtilisation:   clusterResource.DiskAvail / clusterResource.DiskTotal,
	}
	tracker.SyncClusterLoad(param)
	return nil
}

// UpdateAiTask applies the fields of task to its stored record.
func (s *AiStorage) UpdateAiTask(task *models.TaskAi) error {
	if tx := s.DbEngin.Updates(task); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// UpdateTaskByModel applies the fields of task to its stored record.
func (s *AiStorage) UpdateTaskByModel(task *models.Task) error {
	if tx := s.DbEngin.Updates(task); tx.Error != nil {
		return tx.Error
	}
	return nil
}

// GetStrategyCode looks up the dictionary item value of the scheduling
// strategy named name (dict_code 'schedule_Strategy'). It returns 0 and the
// error when the query fails.
func (s *AiStorage) GetStrategyCode(name string) (int64, error) {
	var strategy int64
	sqlStr := `select t_dict_item.item_value from t_dict left join t_dict_item on t_dict.id = t_dict_item.dict_id where item_text = ? and t_dict.dict_code = 'schedule_Strategy'`
	// Query the scheduling strategy.
	if err := s.DbEngin.Raw(sqlStr, name).Scan(&strategy).Error; err != nil {
		return 0, err
	}
	return strategy, nil
}

// AddNoticeInfo inserts a record into t_notice. It is best effort: id parse
// failures and insert failures are logged, never returned. An id that fails to
// parse is stored as 0, and an empty clusterId is stored as 0 without logging.
func (s *AiStorage) AddNoticeInfo(adapterId string, adapterName string, clusterId string, clusterName string, taskName string, noticeType string, incident string) {
	aId, err := strconv.ParseInt(adapterId, 10, 64)
	if err != nil {
		logx.Errorf("adapterId convert failure, err: %v", err)
	}
	var cId int64
	if clusterId != "" {
		cId, err = strconv.ParseInt(clusterId, 10, 64)
		if err != nil {
			logx.Errorf("clusterId convert failure, err: %v", err)
		}
	}
	noticeInfo := clientCore.NoticeInfo{
		AdapterId:   aId,
		AdapterName: adapterName,
		ClusterId:   cId,
		ClusterName: clusterName,
		NoticeType:  noticeType,
		TaskName:    taskName,
		Incident:    incident,
		CreatedTime: time.Now(),
	}
	result := s.DbEngin.Table("t_notice").Create(&noticeInfo)
	if result.Error != nil {
		logx.Errorf("Task creation failure, err: %v", result.Error)
	}
}

// SaveInferDeployInstance inserts an inference deploy instance with status
// constants.Deploying and returns its generated id. Create and update times
// are both set to the current time in RFC 3339 format.
func (s *AiStorage) SaveInferDeployInstance(taskId int64, instanceId string, instanceName string, adapterId int64, adapterName string, clusterId int64, clusterName string, modelName string, modelType string, inferCard string, clusterType string) (int64, error) {
	startTime := time.Now().Format(time.RFC3339)

	// Build the deploy instance record.
	insModel := models.AiInferDeployInstance{
		DeployInstanceTaskId: taskId,
		InstanceId:           instanceId,
		InstanceName:         instanceName,
		AdapterId:            adapterId,
		AdapterName:          adapterName,
		ClusterId:            clusterId,
		ClusterName:          clusterName,
		ModelName:            modelName,
		ModelType:            modelType,
		InferCard:            inferCard,
		ClusterType:          clusterType,
		Status:               constants.Deploying,
		CreateTime:           startTime,
		UpdateTime:           startTime,
	}

	// Persist the deploy instance.
	tx := s.DbEngin.Table("ai_infer_deploy_instance").Create(&insModel)
	if tx.Error != nil {
		return 0, tx.Error
	}
	return insModel.Id, nil
}

// UpdateInferDeployInstance applies the fields of instance to
// ai_infer_deploy_instance. When needUpdateTime is true, instance.UpdateTime
// is first set to the current time in RFC 3339 format.
func (s *AiStorage) UpdateInferDeployInstance(instance *models.AiInferDeployInstance, needUpdateTime bool) error {
	if needUpdateTime {
		instance.UpdateTime = time.Now().Format(time.RFC3339)
	}
	tx := s.DbEngin.Table("ai_infer_deploy_instance").Updates(instance)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return tx.Error
	}
	return nil
}

// GetTaskById returns the task row with the given id.
// NOTE(review): the result may be nil with a nil error, presumably when no
// row matches — confirm against the GORM version in use.
func (s *AiStorage) GetTaskById(id int64) (*models.Task, error) {
	var task *models.Task
	tx := s.DbEngin.Raw("select * from task where `id` = ?", id).Scan(&task)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return task, nil
}

// GetInferDeployInstanceById returns the ai_infer_deploy_instance row with the
// given id.
// NOTE(review): the result may be nil with a nil error, presumably when no
// row matches — confirm against the GORM version in use.
func (s *AiStorage) GetInferDeployInstanceById(id int64) (*models.AiInferDeployInstance, error) {
	var deployIns *models.AiInferDeployInstance
	tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `id` = ?", id).Scan(&deployIns)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return deployIns, nil
}

// GetDeployTaskById returns the ai_deploy_instance_task row with the given id.
// NOTE(review): the result may be nil with a nil error, presumably when no
// row matches — confirm against the GORM version in use.
func (s *AiStorage) GetDeployTaskById(id int64) (*models.AiDeployInstanceTask, error) {
	var task *models.AiDeployInstanceTask
	tx := s.DbEngin.Raw("select * from ai_deploy_instance_task where `id` = ?", id).Scan(&task)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return task, nil
}

// GetDeployTaskListByType returns the deploy tasks whose model_type equals
// modelType.
func (s *AiStorage) GetDeployTaskListByType(modelType string) ([]*models.AiDeployInstanceTask, error) {
	var tasks []*models.AiDeployInstanceTask
	tx := s.DbEngin.Raw("select * from ai_deploy_instance_task where `model_type` = ?", modelType).Scan(&tasks)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return tasks, nil
}

// GetAllDeployTasks returns every row of ai_deploy_instance_task.
func (s *AiStorage) GetAllDeployTasks() ([]*models.AiDeployInstanceTask, error) {
	var tasks []*models.AiDeployInstanceTask
	tx := s.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&tasks)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return tasks, nil
}

// UpdateDeployTask applies the fields of task to ai_deploy_instance_task. When
// needUpdateTime is true, task.UpdateTime is first set to the current time in
// RFC 3339 format.
func (s *AiStorage) UpdateDeployTask(task *models.AiDeployInstanceTask, needUpdateTime bool) error {
	if needUpdateTime {
		task.UpdateTime = time.Now().Format(time.RFC3339)
	}
	tx := s.DbEngin.Table("ai_deploy_instance_task").Updates(task)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return tx.Error
	}
	return nil
}

// DeleteDeployTaskById deletes the deploy task with the given primary key.
func (s *AiStorage) DeleteDeployTaskById(id int64) error {
	tx := s.DbEngin.Delete(&models.AiDeployInstanceTask{}, id)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return tx.Error
	}
	return nil
}

// UpdateDeployTaskById loads the deploy task with the given id and writes it
// back with a refreshed update time. It returns gorm.ErrRecordNotFound when
// the lookup yields no task.
func (s *AiStorage) UpdateDeployTaskById(id int64) error {
	task, err := s.GetDeployTaskById(id)
	if err != nil {
		return err
	}
	// The lookup can yield a nil task without an error; UpdateDeployTask
	// would dereference it.
	if task == nil {
		return gorm.ErrRecordNotFound
	}
	return s.UpdateDeployTask(task, true)
}

// GetInstanceListByDeployTaskId returns the deploy instances whose
// deploy_instance_task_id equals id.
func (s *AiStorage) GetInstanceListByDeployTaskId(id int64) ([]*models.AiInferDeployInstance, error) {
	var list []*models.AiInferDeployInstance
	tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `deploy_instance_task_id` = ?", id).Scan(&list)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return list, nil
}

// GetInferDeployInstanceList returns every row of ai_infer_deploy_instance.
func (s *AiStorage) GetInferDeployInstanceList() ([]*models.AiInferDeployInstance, error) {
	var list []*models.AiInferDeployInstance
	tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance").Scan(&list)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return list, nil
}

// GetDeployTaskList returns every row of ai_deploy_instance_task. It runs the
// same query as GetAllDeployTasks.
func (s *AiStorage) GetDeployTaskList() ([]*models.AiDeployInstanceTask, error) {
	var list []*models.AiDeployInstanceTask
	tx := s.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&list)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return list, nil
}

// GetInferDeployInstanceTotalNum returns the number of rows in
// ai_infer_deploy_instance.
func (s *AiStorage) GetInferDeployInstanceTotalNum() (int32, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from ai_infer_deploy_instance").Scan(&total)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return 0, tx.Error
	}
	return total, nil
}

// GetInferDeployInstanceRunningNum returns the number of deploy instances
// whose status is 'Running'.
func (s *AiStorage) GetInferDeployInstanceRunningNum() (int32, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from ai_infer_deploy_instance where `status` = 'Running'").Scan(&total)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return 0, tx.Error
	}
	return total, nil
}

// GetInferenceTaskTotalNum returns the number of tasks whose task_type_dict is
// 11 or 12.
func (s *AiStorage) GetInferenceTaskTotalNum() (int32, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from task where `task_type_dict` = 11 or `task_type_dict` = 12").Scan(&total)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return 0, tx.Error
	}
	return total, nil
}

// GetInferenceTaskRunningNum returns the number of tasks whose task_type_dict
// is 11 and whose status is 'Running'.
// NOTE(review): the matching total also counts task_type_dict 12 — confirm
// whether type 12 should be included here.
func (s *AiStorage) GetInferenceTaskRunningNum() (int32, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from task where `task_type_dict` = 11 and `status` = 'Running'").Scan(&total)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return 0, tx.Error
	}
	return total, nil
}

// GetTrainingTaskTotalNum returns the number of tasks whose task_type_dict is
// 10.
func (s *AiStorage) GetTrainingTaskTotalNum() (int32, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from task where `task_type_dict` = 10").Scan(&total)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return 0, tx.Error
	}
	return total, nil
}

// GetTrainingTaskRunningNum returns the number of tasks whose task_type_dict
// is 10 and whose status is 'Running'.
func (s *AiStorage) GetTrainingTaskRunningNum() (int32, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from task where `task_type_dict` = 10 and `status` = 'Running'").Scan(&total)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return 0, tx.Error
	}
	return total, nil
}

// SaveInferDeployTask inserts an inference deploy task and returns its
// generated id. Create and update times are both set to the current time in
// RFC 3339 format.
func (s *AiStorage) SaveInferDeployTask(taskName string, userId int64, modelName string, modelType string, desc string) (int64, error) {
	startTime := time.Now().Format(time.RFC3339)

	// Build the deploy task record.
	taskModel := models.AiDeployInstanceTask{
		Name:       taskName,
		UserId:     userId,
		ModelName:  modelName,
		ModelType:  modelType,
		Desc:       desc,
		CreateTime: startTime,
		UpdateTime: startTime,
	}

	// Persist the deploy task.
	tx := s.DbEngin.Table("ai_deploy_instance_task").Create(&taskModel)
	if tx.Error != nil {
		return 0, tx.Error
	}
	return taskModel.Id, nil
}

// GetRunningDeployInstanceById returns the deploy instances of deploy task id
// on the given adapter whose status is 'Running'.
func (s *AiStorage) GetRunningDeployInstanceById(id int64, adapterId string) ([]*models.AiInferDeployInstance, error) {
	var list []*models.AiInferDeployInstance
	tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `deploy_instance_task_id` = ? and `adapter_id` = ? and `status` = 'Running'", id, adapterId).Scan(&list)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return list, nil
}

// IsDeployTaskNameDuplicated reports whether at least one deploy task already
// uses the given name.
func (s *AiStorage) IsDeployTaskNameDuplicated(name string) (bool, error) {
	var total int32
	tx := s.DbEngin.Raw("select count(*) from ai_deploy_instance_task where `name` = ?", name).Scan(&total)
	if tx.Error != nil {
		return false, tx.Error
	}
	return total != 0, nil
}

// GetClustersById returns the t_cluster record scanned for the given id. The
// returned pointer is never nil on success.
func (s *AiStorage) GetClustersById(id string) (*types.ClusterInfo, error) {
	var resp types.ClusterInfo
	tx := s.DbEngin.Raw("select * from t_cluster where `id` = ? ", id).Scan(&resp)
	if tx.Error != nil {
		logx.Error(tx.Error)
		return nil, tx.Error
	}
	return &resp, nil
}