package hpc

import (
	"context"
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"text/template"
	"time"

	jsoniter "github.com/json-iterator/go"
	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"
	"github.com/zeromicro/go-zero/core/logc"
	"github.com/zeromicro/go-zero/core/logx"
	clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
)

// CommitHpcTaskLogic handles submission of an HPC task: it renders the job
// script, submits it through the adapter's executor and persists the result.
type CommitHpcTaskLogic struct {
	logx.Logger
	ctx        context.Context
	svcCtx     *svc.ServiceContext
	hpcService *service.HpcService
}

// JobRequest is the data handed to a job-script template.
type JobRequest struct {
	App         string                 `json:"app"`
	Common      CommonParams           `json:"common"`
	AppSpecific map[string]interface{} `json:"appSpecific"`
}

// CommonParams holds the scheduler parameters shared by every application.
type CommonParams struct {
	JobName   string `json:"jobName"`
	Partition string `json:"partition"`
	Nodes     string `json:"nodes"`
	NTasks    string `json:"ntasks"`
	Time      string `json:"time,omitempty"`
	App       string `json:"app"`
}

// NewCommitHpcTaskLogic builds a CommitHpcTaskLogic bound to ctx and svcCtx.
//
// It returns nil when the underlying HpcService cannot be created; the cause
// is logged so the failure is not silent. Callers must check for nil.
func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
	cache := make(map[string]interface{}, 10)
	hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
	if err != nil {
		logx.WithContext(ctx).Errorf("failed to create HPC service: %v", err)
		return nil
	}
	return &CommitHpcTaskLogic{
		Logger:     logx.WithContext(ctx),
		ctx:        ctx,
		svcCtx:     svcCtx,
		hpcService: hpcService,
	}
}

// templateCache caches parsed templates, keyed by the raw template content.
var templateCache sync.Map

// templateFuncErrRe matches the "error calling <func>: <message>" text that
// extractUserError strips from template execution errors. It is compiled once
// instead of on every call.
var templateFuncErrRe = regexp.MustCompile(`error calling \w+: (.*)$`)

// getClusterInfo loads the cluster row with the given id from t_cluster and
// the adapter row it references from t_adapter.
//
// It returns an error when either query fails or yields a row with an empty id.
func (l *CommitHpcTaskLogic) getClusterInfo(clusterID string) (*types.ClusterInfo, *types.AdapterInfo, error) {
	var clusterInfo types.ClusterInfo
	if err := l.svcCtx.DbEngin.Table("t_cluster").Where("id = ?", clusterID).First(&clusterInfo).Error; err != nil {
		return nil, nil, fmt.Errorf("cluster query failed: %w", err)
	}
	if clusterInfo.Id == "" {
		return nil, nil, errors.New("cluster not found")
	}

	var adapterInfo types.AdapterInfo
	if err := l.svcCtx.DbEngin.Table("t_adapter").Where("id = ?", clusterInfo.AdapterId).First(&adapterInfo).Error; err != nil {
		return nil, nil, fmt.Errorf("adapter query failed: %w", err)
	}
	if adapterInfo.Id == "" {
		return nil, nil, errors.New("adapter not found")
	}
	return &clusterInfo, &adapterInfo, nil
}

// createFuncMap returns the custom functions made available to job-script
// templates.
func createFuncMap() template.FuncMap {
	return template.FuncMap{
		"regexMatch": regexMatch,
		"required":   required,
		"error":      errorHandler,
		"default":    defaultHandler,
	}
}

// extractUserError returns only the message that follows
// "error calling <func>: " in originalErr, so callers see the text produced by
// a template function rather than the engine's wrapping. Errors that do not
// match that format are returned unchanged; a nil error yields nil.
func extractUserError(originalErr error) error {
	if originalErr == nil {
		return nil
	}
	if matches := templateFuncErrRe.FindStringSubmatch(originalErr.Error()); len(matches) > 1 {
		return errors.New(matches[1])
	}
	return originalErr
}

// regexMatch compiles pattern for use inside a template. It uses
// regexp.MustCompile, so an invalid pattern panics.
func regexMatch(pattern string) *regexp.Regexp {
	return regexp.MustCompile(pattern)
}

// required returns val unchanged, or an error carrying msg when val is nil or
// the empty string.
func required(msg string, val interface{}) (interface{}, error) {
	if val == nil || val == "" {
		return nil, errors.New(msg)
	}
	return val, nil
}

// errorHandler always fails with msg; templates call it to abort rendering.
func errorHandler(msg string) (string, error) {
	return "", errors.New(msg)
}

// defaultHandler returns defaultVal when val is nil, an empty string or the
// int 0; any other value is returned as is.
func defaultHandler(defaultVal interface{}, val interface{}) interface{} {
	switch v := val.(type) {
	case nil:
		return defaultVal
	case string:
		if v == "" {
			return defaultVal
		}
	case int:
		if v == 0 {
			return defaultVal
		}
		// Further types can be added here as needed.
	}
	return val
}

// RenderJobScript executes templateContent with req and returns the rendered
// script. Inside the template, req.Common is exposed as "Common" and
// req.AppSpecific as "App".
//
// Parsed templates are cached by content. Parse errors are returned as is;
// execution errors are logged and reduced by extractUserError.
func (l *CommitHpcTaskLogic) RenderJobScript(templateContent string, req *JobRequest) (string, error) {
	if req == nil {
		return "", errors.New("job request is nil")
	}

	cached, ok := templateCache.Load(templateContent)
	if !ok {
		parsed, err := template.New("slurmTemplate").Funcs(createFuncMap()).Parse(templateContent)
		if err != nil {
			return "", err
		}
		// LoadOrStore keeps a single instance when several goroutines parse
		// the same content at the same time.
		cached, _ = templateCache.LoadOrStore(templateContent, parsed)
	}

	params := map[string]interface{}{
		"Common": req.Common,
		"App":    req.AppSpecific,
	}
	var buf strings.Builder
	if err := cached.(*template.Template).Execute(&buf, params); err != nil {
		log.Error().Err(err).Msg("模板渲染失败")
		return "", extractUserError(err)
	}
	return buf.String(), nil
}

// ConvertToJobRequest builds a JobRequest from job.
//
// It fails when job is nil or when any of the "jobName", "nodes" or "ntasks"
// parameters is missing or empty.
func ConvertToJobRequest(job *types.CommitHpcTaskReq) (JobRequest, error) {
	if job == nil {
		return JobRequest{}, errors.New("job is nil")
	}

	// Named requiredFields so it does not shadow the package-level required.
	requiredFields := []string{"jobName", "nodes", "ntasks"}
	for _, field := range requiredFields {
		if job.Parameters[field] == "" {
			return JobRequest{}, fmt.Errorf("%s is empty", field)
		}
	}

	common := CommonParams{
		JobName:   job.Parameters["jobName"],
		Partition: job.Parameters["partition"],
		Nodes:     job.Parameters["nodes"],
		NTasks:    job.Parameters["ntasks"],
		Time:      job.Parameters["time"],
		App:       job.App,
	}
	return JobRequest{
		App:         job.App,
		Common:      common,
		AppSpecific: utils.MpaStringToInterface(job.Parameters),
	}, nil
}

// SaveHpcTaskToDB stores a submitted job as rows in the task, task_hpc and
// t_notice tables inside one transaction and returns the new task id as a
// string.
//
// jobScript is the submitted script, jobId and workDir are the values reported
// for the submitted job. The transaction is rolled back when any step fails or
// panics; a panic is converted into the returned error.
func (l *CommitHpcTaskLogic) SaveHpcTaskToDB(req *types.CommitHpcTaskReq, jobScript, jobId, workDir string) (taskId string, err error) {
	// Use a transaction so the three inserts succeed or fail together.
	tx := l.svcCtx.DbEngin.Begin()
	if tx.Error != nil {
		return "", fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			err = fmt.Errorf("transaction panic: %v", r)
		} else if err != nil {
			tx.Rollback()
		}
	}()

	// Best effort: a missing or malformed user id is stored as 0.
	userID, _ := strconv.ParseInt(req.Parameters[constants.UserId], 10, 64)
	taskID := utils.GenSnowflakeID()
	taskModel := models.Task{
		Id:              taskID,
		Name:            req.Name,
		Description:     req.Description,
		CommitTime:      time.Now(),
		Status:          constants.StatusSaved,
		AdapterTypeDict: constants.AdapterTypeHPC,
		UserId:          userID,
		UserName:        req.Parameters[constants.UserName],
	}
	if err = tx.Table("task").Create(&taskModel).Error; err != nil {
		return "", fmt.Errorf("failed to create task: %w", err)
	}

	clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
	if err != nil {
		return "", err
	}

	paramsJSON, err := jsoniter.MarshalToString(req)
	if err != nil {
		return "", fmt.Errorf("failed to marshal parameters: %w", err)
	}

	// Parse the resource request out of the script for Slurm backends.
	var resource models.ResourceSpec
	if req.Backend == string(constants.HPC_SYSTEM_SLURM) {
		parser := utils.NewSlurmParser()
		slurmResource := parser.ParseScript(jobScript)
		resource = models.ResourceSpec{
			// The resource spec name is "<cluster name>_<partition>".
			ResourceName:   fmt.Sprintf("%s_%s", clusterInfo.Name, slurmResource.Partition),
			Partition:      slurmResource.Partition,
			Specifications: slurmResource,
		}
	}

	clusterID := utils.StringToInt64(clusterInfo.Id)
	hpcTask := models.TaskHpc{
		Id:           utils.GenSnowflakeID(),
		TaskId:       taskID,
		AdapterId:    clusterInfo.AdapterId,
		AdapterName:  adapterInfo.Name,
		ClusterId:    clusterID,
		ClusterName:  clusterInfo.Name,
		Name:         taskModel.Name,
		Backend:      req.Backend,
		OperateType:  req.OperateType,
		CmdScript:    req.Parameters["cmdScript"],
		WallTime:     req.Parameters["wallTime"],
		AppType:      req.Parameters["appType"],
		AppName:      req.App,
		Queue:        req.Parameters["queue"],
		SubmitType:   req.Parameters["submitType"],
		NNode:        req.Parameters["nNode"],
		Account:      clusterInfo.Username,
		StdInput:     req.Parameters["stdInput"],
		Partition:    req.Parameters["partition"],
		CreatedTime:  time.Now(),
		UpdatedTime:  time.Now(),
		Status:       constants.StatusDeploying,
		UserId:       userID,
		Params:       paramsJSON,
		Script:       jobScript,
		JobId:        jobId,
		WorkDir:      workDir,
		ResourceSpec: resource,
	}
	if err = tx.Table("task_hpc").Create(&hpcTask).Error; err != nil {
		return "", fmt.Errorf("failed to create HPC task: %w", err)
	}

	noticeInfo := clientCore.NoticeInfo{
		AdapterId:   clusterInfo.AdapterId,
		AdapterName: adapterInfo.Name,
		ClusterId:   clusterID,
		ClusterName: clusterInfo.Name,
		NoticeType:  "create",
		TaskName:    req.Name,
		TaskId:      taskID,
		Incident:    "任务创建中",
		CreatedTime: time.Now(),
	}
	if err = tx.Table("t_notice").Create(&noticeInfo).Error; err != nil {
		return "", fmt.Errorf("failed to create notice: %w", err)
	}

	if err = tx.Commit().Error; err != nil {
		return "", fmt.Errorf("transaction commit failed: %w", err)
	}
	return utils.Int64ToString(taskID), nil
}

// CommitHpcTask submits the HPC job described by req and records it.
//
// When req.ScriptContent is empty the script is rendered from the
// hpc_app_template row matching the cluster, the app and, if set, the operate
// type. The script is submitted through the executor registered for the
// cluster's adapter, the task is saved with SaveHpcTaskToDB, and the new task
// id is added to the response's JobInfo under "taskId".
//
// req.Parameters["jobName"] is overwritten with the generated job name.
func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
	if req == nil {
		return nil, errors.New("request is nil")
	}

	reqJSON, err := jsoniter.MarshalToString(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	logc.Infof(l.ctx, "提交超算任务请求参数: %s", reqJSON)

	jobName := generateJobName(req)
	// Writing to a nil map panics, so make sure the map exists first.
	// NOTE(review): assumes Parameters is a map[string]string — confirm.
	if req.Parameters == nil {
		req.Parameters = make(map[string]string)
	}
	req.Parameters["jobName"] = jobName

	// Look up the cluster and its adapter.
	clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
	if err != nil {
		return nil, err
	}

	scriptContent := req.ScriptContent
	if scriptContent == "" {
		// Fetch the template.
		var templateInfo types.HpcAppTemplateInfo
		query := l.svcCtx.DbEngin.Table("hpc_app_template").
			Where("cluster_id = ? and app = ? ", req.ClusterId, req.App)
		if req.OperateType != "" {
			// Keep the returned handle: discarding it relies on the ORM
			// mutating the receiver in place.
			query = query.Where("app_type = ?", req.OperateType)
		}
		if err := query.First(&templateInfo).Error; err != nil {
			return nil, fmt.Errorf("获取HPC应用【%s】模板失败: %w", req.App, err)
		}

		// Convert the request parameters.
		jobRequest, err := ConvertToJobRequest(req)
		if err != nil {
			return nil, err
		}

		// Render the script.
		script, err := l.RenderJobScript(templateInfo.Content, &jobRequest)
		if err != nil {
			return nil, err
		}
		scriptContent = script
	}

	// The marshalled script is only used for the log line below, so a
	// marshalling failure is deliberately ignored.
	q, _ := jsoniter.MarshalToString(scriptContent)
	submitQ := types.SubmitHpcTaskReq{
		App:           req.App,
		ClusterId:     req.ClusterId,
		JobName:       jobName,
		ScriptContent: scriptContent,
		Parameters:    req.Parameters,
		Backend:       req.Backend,
	}
	log.Info().Msgf("Submitting HPC task to cluster %s with params: %s", clusterInfo.Name, q)

	// Fail with an error instead of calling through a missing map entry.
	executor, ok := l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id]
	if !ok {
		return nil, fmt.Errorf("no HPC executor registered for adapter %v", adapterInfo.Id)
	}
	resp, err = executor.SubmitTask(l.ctx, submitQ)
	if err != nil {
		log.Error().Err(err).Msgf("提交超算任务失败, cluster: %s, jobName: %s, scriptContent: %s", clusterInfo.Name, jobName, scriptContent)
		return nil, fmt.Errorf("网络请求失败,请稍后重试")
	}
	if resp == nil {
		return nil, errors.New("submit task returned an empty response")
	}

	jobID := resp.Data.JobInfo["jobId"]
	workDir := resp.Data.JobInfo["jobDir"]
	taskID, err := l.SaveHpcTaskToDB(req, scriptContent, jobID, workDir)
	if err != nil {
		log.Error().Msgf("超算任务保存到数据库失败, cluster: %s, jobName: %s, scriptContent: %s, error: %v", clusterInfo.Name, jobName, scriptContent, err)
		return nil, fmt.Errorf("保存超算任务到数据库失败: %w", err)
	}

	// NOTE(review): assumes the executor always returns a non-nil JobInfo map;
	// writing to a nil map would panic — confirm against the executors.
	resp.Data.JobInfo["taskId"] = taskID
	return resp, nil
}

// generateJobName returns req.Name, suffixed with "_" and req.OperateType when
// an operate type is set.
func generateJobName(req *types.CommitHpcTaskReq) string {
	if req.OperateType == "" {
		return req.Name
	}
	return req.Name + "_" + req.OperateType
}