You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

commithpctasklogic.go 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. package hpc
  2. import (
  3. "context"
  4. "fmt"
  5. jsoniter "github.com/json-iterator/go"
  6. "github.com/pkg/errors"
  7. "github.com/rs/zerolog/log"
  8. "github.com/zeromicro/go-zero/core/logc"
  9. "github.com/zeromicro/go-zero/core/logx"
  10. clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
  11. "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
  12. "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
  13. "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
  14. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
  15. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
  16. "regexp"
  17. "strconv"
  18. "strings"
  19. "sync"
  20. "text/template"
  21. "time"
  22. )
  23. type CommitHpcTaskLogic struct {
  24. logx.Logger
  25. ctx context.Context
  26. svcCtx *svc.ServiceContext
  27. hpcService *service.HpcService
  28. }
  29. const (
  30. statusSaved = "Saved"
  31. statusDeploying = "Deploying"
  32. adapterTypeHPC = "2"
  33. )
  34. type JobRequest struct {
  35. App string `json:"app"`
  36. Common CommonParams `json:"common"`
  37. AppSpecific map[string]interface{} `json:"appSpecific"`
  38. }
  39. type CommonParams struct {
  40. JobName string `json:"jobName"`
  41. Partition string `json:"partition"`
  42. Nodes string `json:"nodes"`
  43. NTasks string `json:"ntasks"`
  44. Time string `json:"time,omitempty"`
  45. App string `json:"app"`
  46. }
  47. func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
  48. cache := make(map[string]interface{}, 10)
  49. hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
  50. if err != nil {
  51. return nil
  52. }
  53. return &CommitHpcTaskLogic{
  54. Logger: logx.WithContext(ctx),
  55. ctx: ctx,
  56. svcCtx: svcCtx,
  57. hpcService: hpcService,
  58. }
  59. }
  60. // 新增:缓存模板对象
  61. var templateCache = sync.Map{}
  62. func (l *CommitHpcTaskLogic) getClusterInfo(clusterID string) (*types.ClusterInfo, *types.AdapterInfo, error) {
  63. var clusterInfo types.ClusterInfo
  64. if err := l.svcCtx.DbEngin.Table("t_cluster").Where("id = ?", clusterID).First(&clusterInfo).Error; err != nil {
  65. return nil, nil, fmt.Errorf("cluster query failed: %w", err)
  66. }
  67. if clusterInfo.Id == "" {
  68. return nil, nil, errors.New("cluster not found")
  69. }
  70. var adapterInfo types.AdapterInfo
  71. if err := l.svcCtx.DbEngin.Table("t_adapter").Where("id = ?", clusterInfo.AdapterId).First(&adapterInfo).Error; err != nil {
  72. return nil, nil, fmt.Errorf("adapter query failed: %w", err)
  73. }
  74. if adapterInfo.Id == "" {
  75. return nil, nil, errors.New("adapter not found")
  76. }
  77. return &clusterInfo, &adapterInfo, nil
  78. }
  79. // 自定义函数映射
  80. func createFuncMap() template.FuncMap {
  81. return template.FuncMap{
  82. "regexMatch": regexMatch,
  83. "required": required,
  84. "error": errorHandler,
  85. "default": defaultHandler,
  86. }
  87. }
  88. func extractUserError(originalErr error) error {
  89. // 尝试匹配模板引擎返回的错误格式
  90. re := regexp.MustCompile(`error calling \w+: (.*)$`)
  91. matches := re.FindStringSubmatch(originalErr.Error())
  92. if len(matches) > 1 {
  93. return errors.New(matches[1])
  94. }
  95. return originalErr
  96. }
  97. // 正则匹配函数
  98. func regexMatch(pattern string) *regexp.Regexp {
  99. return regexp.MustCompile(pattern)
  100. }
  101. // 必填字段检查
  102. func required(msg string, val interface{}) (interface{}, error) {
  103. if val == nil || val == "" {
  104. return nil, errors.New(msg)
  105. }
  106. return val, nil
  107. }
  108. // 错误处理函数
  109. func errorHandler(msg string) (string, error) {
  110. return "", errors.New(msg)
  111. }
  112. // 默认值处理函数
  113. func defaultHandler(defaultVal interface{}, val interface{}) interface{} {
  114. switch v := val.(type) {
  115. case nil:
  116. return defaultVal
  117. case string:
  118. if v == "" {
  119. return defaultVal
  120. }
  121. case int:
  122. if v == 0 {
  123. return defaultVal
  124. }
  125. // 可根据需要添加其他类型判断
  126. }
  127. return val
  128. }
  129. func (l *CommitHpcTaskLogic) RenderJobScript(templateContent string, req *JobRequest) (string, error) {
  130. // 使用缓存模板
  131. tmpl, ok := templateCache.Load(templateContent)
  132. if !ok {
  133. parsedTmpl, err := template.New("slurmTemplate").Funcs(createFuncMap()).Parse(templateContent)
  134. if err != nil {
  135. return "", err
  136. }
  137. templateCache.Store(templateContent, parsedTmpl)
  138. tmpl = parsedTmpl
  139. }
  140. params := map[string]interface{}{
  141. "Common": req.Common,
  142. "App": req.AppSpecific,
  143. }
  144. var buf strings.Builder
  145. if err := tmpl.(*template.Template).Execute(&buf, params); err != nil {
  146. log.Error().Err(err).Msg("模板渲染失败")
  147. return "", extractUserError(err)
  148. }
  149. return buf.String(), nil
  150. }
  151. func ConvertToJobRequest(job *types.CommitHpcTaskReq) (JobRequest, error) {
  152. required := []string{"jobName", "nodes", "ntasks"}
  153. for _, field := range required {
  154. if job.Parameters[field] == "" {
  155. return JobRequest{}, fmt.Errorf("%s is empty", field)
  156. }
  157. }
  158. return JobRequest{
  159. App: job.App,
  160. Common: CommonParams{
  161. JobName: job.Parameters["jobName"],
  162. Partition: job.Parameters["partition"],
  163. Nodes: job.Parameters["nodes"],
  164. NTasks: job.Parameters["ntasks"],
  165. Time: job.Parameters["time"],
  166. App: job.App,
  167. },
  168. AppSpecific: utils.MpaStringToInterface(job.Parameters),
  169. }, nil
  170. }
  171. func (l *CommitHpcTaskLogic) SaveHpcTaskToDB(req *types.CommitHpcTaskReq, jobScript, jobId, workDir string) (taskId string, err error) {
  172. // 使用事务确保数据一致性
  173. tx := l.svcCtx.DbEngin.Begin()
  174. defer func() {
  175. if r := recover(); r != nil {
  176. tx.Rollback()
  177. err = fmt.Errorf("transaction panic: %v", r)
  178. } else if err != nil {
  179. tx.Rollback()
  180. }
  181. }()
  182. userID, _ := strconv.ParseInt(req.Parameters["UserId"], 10, 64)
  183. taskID := utils.GenSnowflakeID()
  184. taskModel := models.Task{
  185. Id: taskID,
  186. Name: req.Name,
  187. Description: req.Description,
  188. CommitTime: time.Now(),
  189. Status: statusSaved,
  190. AdapterTypeDict: adapterTypeHPC,
  191. UserId: userID,
  192. }
  193. if err = tx.Table("task").Create(&taskModel).Error; err != nil {
  194. return "", fmt.Errorf("failed to create task: %w", err)
  195. }
  196. clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
  197. if err != nil {
  198. return "", err
  199. }
  200. paramsJSON, err := jsoniter.MarshalToString(req)
  201. if err != nil {
  202. return "", fmt.Errorf("failed to marshal parameters: %w", err)
  203. }
  204. clusterID := utils.StringToInt64(clusterInfo.Id)
  205. hpcTask := models.TaskHpc{
  206. Id: utils.GenSnowflakeID(),
  207. TaskId: taskID,
  208. AdapterId: clusterInfo.AdapterId,
  209. AdapterName: adapterInfo.Name,
  210. ClusterId: clusterID,
  211. ClusterName: clusterInfo.Name,
  212. Name: taskModel.Name,
  213. Backend: req.Backend,
  214. OperateType: req.OperateType,
  215. CmdScript: req.Parameters["cmdScript"],
  216. WallTime: req.Parameters["wallTime"],
  217. AppType: req.Parameters["appType"],
  218. AppName: req.App,
  219. Queue: req.Parameters["queue"],
  220. SubmitType: req.Parameters["submitType"],
  221. NNode: req.Parameters["nNode"],
  222. Account: clusterInfo.Username,
  223. StdInput: req.Parameters["stdInput"],
  224. Partition: req.Parameters["partition"],
  225. CreatedTime: time.Now(),
  226. UpdatedTime: time.Now(),
  227. Status: statusDeploying,
  228. UserId: userID,
  229. Params: paramsJSON,
  230. Script: jobScript,
  231. JobId: jobId,
  232. WorkDir: workDir,
  233. }
  234. if err = tx.Table("task_hpc").Create(&hpcTask).Error; err != nil {
  235. return "", fmt.Errorf("failed to create HPC task: %w", err)
  236. }
  237. noticeInfo := clientCore.NoticeInfo{
  238. AdapterId: clusterInfo.AdapterId,
  239. AdapterName: adapterInfo.Name,
  240. ClusterId: clusterID,
  241. ClusterName: clusterInfo.Name,
  242. NoticeType: "create",
  243. TaskName: req.Name,
  244. TaskId: taskID,
  245. Incident: "任务创建中",
  246. CreatedTime: time.Now(),
  247. }
  248. if err = tx.Table("t_notice").Create(&noticeInfo).Error; err != nil {
  249. return "", fmt.Errorf("failed to create notice: %w", err)
  250. }
  251. if err = tx.Commit().Error; err != nil {
  252. return "", fmt.Errorf("transaction commit failed: %w", err)
  253. }
  254. return utils.Int64ToString(taskID), nil
  255. }
  256. func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
  257. reqJSON, err := jsoniter.MarshalToString(req)
  258. if err != nil {
  259. return nil, fmt.Errorf("failed to marshal request: %w", err)
  260. }
  261. logc.Infof(l.ctx, "提交超算任务请求参数: %s", reqJSON)
  262. jobName := generateJobName(req)
  263. req.Parameters["jobName"] = jobName
  264. // 获取集群和适配器信息
  265. clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
  266. if err != nil {
  267. return nil, err
  268. }
  269. scriptContent := req.ScriptContent
  270. if scriptContent == "" {
  271. // 获取模板
  272. var templateInfo types.HpcAppTemplateInfo
  273. tx := l.svcCtx.DbEngin.Table("hpc_app_template").
  274. Where("cluster_id = ? and app = ? ", req.ClusterId, req.App)
  275. if req.OperateType != "" {
  276. tx.Where("app_type = ?", req.OperateType)
  277. }
  278. if err := tx.First(&templateInfo).Error; err != nil {
  279. return nil, fmt.Errorf("获取HPC应用【%s】模板失败: %w", req.App, err)
  280. }
  281. // 转换请求参数
  282. jobRequest, err := ConvertToJobRequest(req)
  283. if err != nil {
  284. return nil, err
  285. }
  286. // 渲染脚本
  287. script, err := l.RenderJobScript(templateInfo.Content, &jobRequest)
  288. if err != nil {
  289. return nil, err
  290. }
  291. scriptContent = script
  292. }
  293. q, _ := jsoniter.MarshalToString(scriptContent)
  294. submitQ := types.SubmitHpcTaskReq{
  295. App: req.App,
  296. ClusterId: req.ClusterId,
  297. JobName: jobName,
  298. ScriptContent: scriptContent,
  299. Parameters: req.Parameters,
  300. Backend: req.Backend,
  301. }
  302. log.Info().Msgf("Submitting HPC task to cluster %s with params: %s", clusterInfo.Name, q)
  303. resp, err = l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id].SubmitTask(l.ctx, submitQ)
  304. if err != nil {
  305. log.Error().Err(err).Msgf("提交超算任务失败, cluster: %s, jobName: %s, scriptContent: %s", clusterInfo.Name, jobName, scriptContent)
  306. return nil, fmt.Errorf("网络请求失败,请稍后重试")
  307. }
  308. jobID := resp.Data.JobInfo["jobId"]
  309. workDir := resp.Data.JobInfo["jobDir"]
  310. taskID, err := l.SaveHpcTaskToDB(req, scriptContent, jobID, workDir)
  311. if err != nil {
  312. log.Error().Msgf("超算任务保存到数据库失败, cluster: %s, jobName: %s, scriptContent: %s, error: %v", clusterInfo.Name, jobName, scriptContent, err)
  313. return nil, fmt.Errorf("保存超算任务到数据库失败: %w", err)
  314. }
  315. resp.Data.JobInfo["taskId"] = taskID
  316. return resp, nil
  317. }
  318. func generateJobName(req *types.CommitHpcTaskReq) string {
  319. if req.OperateType == "" {
  320. return req.Name
  321. }
  322. return req.Name + "_" + req.OperateType
  323. }

PCM is positioned as a software stack over the cloud, aiming to build the standards and ecosystem of heterogeneous cloud collaboration for JCC in a non-intrusive and autonomous peer-to-peer manner.