|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373 |
- package hpc
-
- import (
- "context"
- "fmt"
- "regexp"
- "strconv"
- "strings"
- "sync"
- "text/template"
- "time"
-
- jsoniter "github.com/json-iterator/go"
- "github.com/pkg/errors"
- "github.com/rs/zerolog/log"
- "github.com/zeromicro/go-zero/core/logc"
- "github.com/zeromicro/go-zero/core/logx"
- clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
- "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
- "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
- "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
- "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
- "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
- "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
- )
-
- type CommitHpcTaskLogic struct {
- logx.Logger
- ctx context.Context
- svcCtx *svc.ServiceContext
- hpcService *service.HpcService
- }
-
- type JobRequest struct {
- App string `json:"app"`
- Common CommonParams `json:"common"`
- AppSpecific map[string]interface{} `json:"appSpecific"`
- }
// CommonParams groups the job parameters that are independent of the
// application. All values are kept as strings; Time is optional and is left
// out of the JSON encoding when empty.
type CommonParams struct {
	JobName   string `json:"jobName"`        // job name
	Partition string `json:"partition"`      // target partition
	Nodes     string `json:"nodes"`          // node count, as text
	NTasks    string `json:"ntasks"`         // task count, as text
	Time      string `json:"time,omitempty"` // optional time value
	App       string `json:"app"`            // application name
}
-
- func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
- cache := make(map[string]interface{}, 10)
- hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
- if err != nil {
- return nil
- }
- return &CommitHpcTaskLogic{
- Logger: logx.WithContext(ctx),
- ctx: ctx,
- svcCtx: svcCtx,
- hpcService: hpcService,
- }
- }
-
// templateCache caches parsed template objects so that a template is not
// parsed again on every render.
var templateCache sync.Map
-
- func (l *CommitHpcTaskLogic) getClusterInfo(clusterID string) (*types.ClusterInfo, *types.AdapterInfo, error) {
- var clusterInfo types.ClusterInfo
- if err := l.svcCtx.DbEngin.Table("t_cluster").Where("id = ?", clusterID).First(&clusterInfo).Error; err != nil {
- return nil, nil, fmt.Errorf("cluster query failed: %w", err)
- }
-
- if clusterInfo.Id == "" {
- return nil, nil, errors.New("cluster not found")
- }
-
- var adapterInfo types.AdapterInfo
- if err := l.svcCtx.DbEngin.Table("t_adapter").Where("id = ?", clusterInfo.AdapterId).First(&adapterInfo).Error; err != nil {
- return nil, nil, fmt.Errorf("adapter query failed: %w", err)
- }
-
- if adapterInfo.Id == "" {
- return nil, nil, errors.New("adapter not found")
- }
-
- return &clusterInfo, &adapterInfo, nil
- }
-
- // 自定义函数映射
- func createFuncMap() template.FuncMap {
- return template.FuncMap{
- "regexMatch": regexMatch,
- "required": required,
- "error": errorHandler,
- "default": defaultHandler,
- }
- }
// templateCallErrorRe matches the "error calling <func>: <message>" tail that
// the template engine puts on errors returned by template functions. It is
// compiled once at package scope instead of on every call.
var templateCallErrorRe = regexp.MustCompile(`error calling \w+: (.*)$`)

// extractUserError reduces a template execution error to the message produced
// by the failing template function.
//
// If originalErr's text contains "error calling <func>: <message>", a new
// error holding only <message> is returned. Otherwise originalErr is returned
// unchanged. A nil originalErr yields nil.
func extractUserError(originalErr error) error {
	if originalErr == nil {
		return nil
	}
	// Try to match the error format produced by the template engine.
	if m := templateCallErrorRe.FindStringSubmatch(originalErr.Error()); len(m) > 1 {
		return errors.New(m[1])
	}
	return originalErr
}
-
// regexMatch is the template regex helper: it compiles pattern and returns
// the compiled expression. It relies on regexp.MustCompile, so an invalid
// pattern panics.
func regexMatch(pattern string) *regexp.Regexp {
	compiled := regexp.MustCompile(pattern)
	return compiled
}
-
// required is the template check for mandatory fields. It returns val
// unchanged unless val is nil or the empty string, in which case it returns
// an error carrying msg.
func required(msg string, val interface{}) (interface{}, error) {
	switch v := val.(type) {
	case nil:
		return nil, errors.New(msg)
	case string:
		if v == "" {
			return nil, errors.New(msg)
		}
	}
	return val, nil
}
-
// errorHandler is the template "error" function: it always fails, returning
// an empty string together with an error that carries msg.
func errorHandler(msg string) (string, error) {
	failure := errors.New(msg)
	return "", failure
}
-
// defaultHandler is the template "default" function. It returns defaultVal
// when val is nil, an empty string or the int 0, and val itself otherwise.
// Other types are never treated as empty; add further checks here if needed.
func defaultHandler(defaultVal interface{}, val interface{}) interface{} {
	if val == nil {
		return defaultVal
	}
	if text, isString := val.(string); isString && text == "" {
		return defaultVal
	}
	if number, isInt := val.(int); isInt && number == 0 {
		return defaultVal
	}
	return val
}
- func (l *CommitHpcTaskLogic) RenderJobScript(templateContent string, req *JobRequest) (string, error) {
- // 使用缓存模板
- tmpl, ok := templateCache.Load(templateContent)
- if !ok {
- parsedTmpl, err := template.New("slurmTemplate").Funcs(createFuncMap()).Parse(templateContent)
- if err != nil {
- return "", err
- }
- templateCache.Store(templateContent, parsedTmpl)
- tmpl = parsedTmpl
- }
-
- params := map[string]interface{}{
- "Common": req.Common,
- "App": req.AppSpecific,
- }
-
- var buf strings.Builder
- if err := tmpl.(*template.Template).Execute(&buf, params); err != nil {
- log.Error().Err(err).Msg("模板渲染失败")
- return "", extractUserError(err)
- }
- return buf.String(), nil
- }
-
- func ConvertToJobRequest(job *types.CommitHpcTaskReq) (JobRequest, error) {
- required := []string{"jobName", "nodes", "ntasks"}
- for _, field := range required {
- if job.Parameters[field] == "" {
- return JobRequest{}, fmt.Errorf("%s is empty", field)
- }
- }
-
- return JobRequest{
- App: job.App,
- Common: CommonParams{
- JobName: job.Parameters["jobName"],
- Partition: job.Parameters["partition"],
- Nodes: job.Parameters["nodes"],
- NTasks: job.Parameters["ntasks"],
- Time: job.Parameters["time"],
- App: job.App,
- },
- AppSpecific: utils.MpaStringToInterface(job.Parameters),
- }, nil
- }
-
- func (l *CommitHpcTaskLogic) SaveHpcTaskToDB(req *types.CommitHpcTaskReq, jobScript, jobId, workDir string) (taskId string, err error) {
- // 使用事务确保数据一致性
- tx := l.svcCtx.DbEngin.Begin()
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- err = fmt.Errorf("transaction panic: %v", r)
- } else if err != nil {
- tx.Rollback()
- }
- }()
-
- userID, _ := strconv.ParseInt(req.Parameters[constants.UserId], 10, 64)
- taskID := utils.GenSnowflakeID()
-
- taskModel := models.Task{
- Id: taskID,
- Name: req.Name,
- Description: req.Description,
- CommitTime: time.Now(),
- Status: constants.StatusSaved,
- AdapterTypeDict: constants.AdapterTypeHPC,
- UserId: userID,
- UserName: req.Parameters[constants.UserName],
- }
-
- if err = tx.Table("task").Create(&taskModel).Error; err != nil {
- return "", fmt.Errorf("failed to create task: %w", err)
- }
-
- clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
- if err != nil {
- return "", err
- }
-
- paramsJSON, err := jsoniter.MarshalToString(req)
- if err != nil {
- return "", fmt.Errorf("failed to marshal parameters: %w", err)
- }
- //解析slurm脚本内容
- var resource models.ResourceSpec
- if req.Backend == string(constants.HPC_SYSTEM_SLURM) {
- parser := utils.NewSlurmParser()
- slurmResource := parser.ParseScript(jobScript)
- resource = models.ResourceSpec{
- //资源规格名称,采用拼接的方式 集群名+队列名
- ResourceName: fmt.Sprintf("%s_%s", clusterInfo.Name, slurmResource.Partition),
- Partition: slurmResource.Partition,
- Specifications: slurmResource,
- }
- }
- clusterID := utils.StringToInt64(clusterInfo.Id)
- hpcTask := models.TaskHpc{
- Id: utils.GenSnowflakeID(),
- TaskId: taskID,
- AdapterId: clusterInfo.AdapterId,
- AdapterName: adapterInfo.Name,
- ClusterId: clusterID,
- ClusterName: clusterInfo.Name,
- Name: taskModel.Name,
- Backend: req.Backend,
- OperateType: req.OperateType,
- CmdScript: req.Parameters["cmdScript"],
- WallTime: req.Parameters["wallTime"],
- AppType: req.Parameters["appType"],
- AppName: req.App,
- Queue: req.Parameters["queue"],
- SubmitType: req.Parameters["submitType"],
- NNode: req.Parameters["nNode"],
- Account: clusterInfo.Username,
- StdInput: req.Parameters["stdInput"],
- Partition: req.Parameters["partition"],
- CreatedTime: time.Now(),
- UpdatedTime: time.Now(),
- Status: constants.StatusDeploying,
- UserId: userID,
- Params: paramsJSON,
- Script: jobScript,
- JobId: jobId,
- WorkDir: workDir,
- ResourceSpec: resource,
- }
-
- if err = tx.Table("task_hpc").Create(&hpcTask).Error; err != nil {
- return "", fmt.Errorf("failed to create HPC task: %w", err)
- }
-
- noticeInfo := clientCore.NoticeInfo{
- AdapterId: clusterInfo.AdapterId,
- AdapterName: adapterInfo.Name,
- ClusterId: clusterID,
- ClusterName: clusterInfo.Name,
- NoticeType: "create",
- TaskName: req.Name,
- TaskId: taskID,
- Incident: "任务创建中",
- CreatedTime: time.Now(),
- }
-
- if err = tx.Table("t_notice").Create(¬iceInfo).Error; err != nil {
- return "", fmt.Errorf("failed to create notice: %w", err)
- }
-
- if err = tx.Commit().Error; err != nil {
- return "", fmt.Errorf("transaction commit failed: %w", err)
- }
-
- return utils.Int64ToString(taskID), nil
- }
-
- func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
- reqJSON, err := jsoniter.MarshalToString(req)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal request: %w", err)
- }
- logc.Infof(l.ctx, "提交超算任务请求参数: %s", reqJSON)
-
- jobName := generateJobName(req)
- req.Parameters["jobName"] = jobName
-
- // 获取集群和适配器信息
- clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
- if err != nil {
- return nil, err
- }
- scriptContent := req.ScriptContent
- if scriptContent == "" {
- // 获取模板
- var templateInfo types.HpcAppTemplateInfo
- tx := l.svcCtx.DbEngin.Table("hpc_app_template").
- Where("cluster_id = ? and app = ? ", req.ClusterId, req.App)
- if req.OperateType != "" {
- tx.Where("app_type = ?", req.OperateType)
- }
- if err := tx.First(&templateInfo).Error; err != nil {
- return nil, fmt.Errorf("获取HPC应用【%s】模板失败: %w", req.App, err)
- }
-
- // 转换请求参数
- jobRequest, err := ConvertToJobRequest(req)
- if err != nil {
- return nil, err
- }
-
- // 渲染脚本
- script, err := l.RenderJobScript(templateInfo.Content, &jobRequest)
- if err != nil {
- return nil, err
- }
- scriptContent = script
- }
-
- q, _ := jsoniter.MarshalToString(scriptContent)
- submitQ := types.SubmitHpcTaskReq{
- App: req.App,
- ClusterId: req.ClusterId,
- JobName: jobName,
- ScriptContent: scriptContent,
- Parameters: req.Parameters,
- Backend: req.Backend,
- }
- log.Info().Msgf("Submitting HPC task to cluster %s with params: %s", clusterInfo.Name, q)
- resp, err = l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id].SubmitTask(l.ctx, submitQ)
- if err != nil {
- log.Error().Err(err).Msgf("提交超算任务失败, cluster: %s, jobName: %s, scriptContent: %s", clusterInfo.Name, jobName, scriptContent)
- return nil, fmt.Errorf("网络请求失败,请稍后重试")
- }
-
- jobID := resp.Data.JobInfo["jobId"]
- workDir := resp.Data.JobInfo["jobDir"]
- taskID, err := l.SaveHpcTaskToDB(req, scriptContent, jobID, workDir)
- if err != nil {
- log.Error().Msgf("超算任务保存到数据库失败, cluster: %s, jobName: %s, scriptContent: %s, error: %v", clusterInfo.Name, jobName, scriptContent, err)
- return nil, fmt.Errorf("保存超算任务到数据库失败: %w", err)
- }
-
- resp.Data.JobInfo["taskId"] = taskID
- return resp, nil
- }
-
- func generateJobName(req *types.CommitHpcTaskReq) string {
- if req.OperateType == "" {
- return req.Name
- }
- return req.Name + "_" + req.OperateType
- }
|