You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

commithpctasklogic.go 11 kB

3 months ago
3 months ago
3 months ago
3 months ago
3 months ago
3 months ago
3 months ago
3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. package hpc
  2. import (
  3. "context"
  4. "fmt"
  5. "regexp"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "text/template"
  10. "time"
  11. jsoniter "github.com/json-iterator/go"
  12. "github.com/pkg/errors"
  13. "github.com/rs/zerolog/log"
  14. "github.com/zeromicro/go-zero/core/logc"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
  17. "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
  18. "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
  19. "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
  20. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
  21. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
  22. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
  23. )
  24. type CommitHpcTaskLogic struct {
  25. logx.Logger
  26. ctx context.Context
  27. svcCtx *svc.ServiceContext
  28. hpcService *service.HpcService
  29. }
  30. type JobRequest struct {
  31. App string `json:"app"`
  32. Common CommonParams `json:"common"`
  33. AppSpecific map[string]interface{} `json:"appSpecific"`
  34. }
  35. type CommonParams struct {
  36. JobName string `json:"jobName"`
  37. Partition string `json:"partition"`
  38. Nodes string `json:"nodes"`
  39. NTasks string `json:"ntasks"`
  40. Time string `json:"time,omitempty"`
  41. App string `json:"app"`
  42. }
  43. func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
  44. cache := make(map[string]interface{}, 10)
  45. hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
  46. if err != nil {
  47. return nil
  48. }
  49. return &CommitHpcTaskLogic{
  50. Logger: logx.WithContext(ctx),
  51. ctx: ctx,
  52. svcCtx: svcCtx,
  53. hpcService: hpcService,
  54. }
  55. }
  56. // 新增:缓存模板对象
  57. var templateCache = sync.Map{}
  58. func (l *CommitHpcTaskLogic) getClusterInfo(clusterID string) (*types.ClusterInfo, *types.AdapterInfo, error) {
  59. var clusterInfo types.ClusterInfo
  60. if err := l.svcCtx.DbEngin.Table("t_cluster").Where("id = ?", clusterID).First(&clusterInfo).Error; err != nil {
  61. return nil, nil, fmt.Errorf("cluster query failed: %w", err)
  62. }
  63. if clusterInfo.Id == "" {
  64. return nil, nil, errors.New("cluster not found")
  65. }
  66. var adapterInfo types.AdapterInfo
  67. if err := l.svcCtx.DbEngin.Table("t_adapter").Where("id = ?", clusterInfo.AdapterId).First(&adapterInfo).Error; err != nil {
  68. return nil, nil, fmt.Errorf("adapter query failed: %w", err)
  69. }
  70. if adapterInfo.Id == "" {
  71. return nil, nil, errors.New("adapter not found")
  72. }
  73. return &clusterInfo, &adapterInfo, nil
  74. }
  75. // 自定义函数映射
  76. func createFuncMap() template.FuncMap {
  77. return template.FuncMap{
  78. "regexMatch": regexMatch,
  79. "required": required,
  80. "error": errorHandler,
  81. "default": defaultHandler,
  82. }
  83. }
  84. func extractUserError(originalErr error) error {
  85. // 尝试匹配模板引擎返回的错误格式
  86. re := regexp.MustCompile(`error calling \w+: (.*)$`)
  87. matches := re.FindStringSubmatch(originalErr.Error())
  88. if len(matches) > 1 {
  89. return errors.New(matches[1])
  90. }
  91. return originalErr
  92. }
  93. // 正则匹配函数
  94. func regexMatch(pattern string) *regexp.Regexp {
  95. return regexp.MustCompile(pattern)
  96. }
  97. // 必填字段检查
  98. func required(msg string, val interface{}) (interface{}, error) {
  99. if val == nil || val == "" {
  100. return nil, errors.New(msg)
  101. }
  102. return val, nil
  103. }
  104. // 错误处理函数
  105. func errorHandler(msg string) (string, error) {
  106. return "", errors.New(msg)
  107. }
  108. // 默认值处理函数
  109. func defaultHandler(defaultVal interface{}, val interface{}) interface{} {
  110. switch v := val.(type) {
  111. case nil:
  112. return defaultVal
  113. case string:
  114. if v == "" {
  115. return defaultVal
  116. }
  117. case int:
  118. if v == 0 {
  119. return defaultVal
  120. }
  121. // 可根据需要添加其他类型判断
  122. }
  123. return val
  124. }
  125. func (l *CommitHpcTaskLogic) RenderJobScript(templateContent string, req *JobRequest) (string, error) {
  126. // 使用缓存模板
  127. tmpl, ok := templateCache.Load(templateContent)
  128. if !ok {
  129. parsedTmpl, err := template.New("slurmTemplate").Funcs(createFuncMap()).Parse(templateContent)
  130. if err != nil {
  131. return "", err
  132. }
  133. templateCache.Store(templateContent, parsedTmpl)
  134. tmpl = parsedTmpl
  135. }
  136. params := map[string]interface{}{
  137. "Common": req.Common,
  138. "App": req.AppSpecific,
  139. }
  140. var buf strings.Builder
  141. if err := tmpl.(*template.Template).Execute(&buf, params); err != nil {
  142. log.Error().Err(err).Msg("模板渲染失败")
  143. return "", extractUserError(err)
  144. }
  145. return buf.String(), nil
  146. }
  147. func ConvertToJobRequest(job *types.CommitHpcTaskReq) (JobRequest, error) {
  148. required := []string{"jobName", "nodes", "ntasks"}
  149. for _, field := range required {
  150. if job.Parameters[field] == "" {
  151. return JobRequest{}, fmt.Errorf("%s is empty", field)
  152. }
  153. }
  154. return JobRequest{
  155. App: job.App,
  156. Common: CommonParams{
  157. JobName: job.Parameters["jobName"],
  158. Partition: job.Parameters["partition"],
  159. Nodes: job.Parameters["nodes"],
  160. NTasks: job.Parameters["ntasks"],
  161. Time: job.Parameters["time"],
  162. App: job.App,
  163. },
  164. AppSpecific: utils.MpaStringToInterface(job.Parameters),
  165. }, nil
  166. }
  167. func (l *CommitHpcTaskLogic) SaveHpcTaskToDB(req *types.CommitHpcTaskReq, jobScript, jobId, workDir string) (taskId string, err error) {
  168. // 使用事务确保数据一致性
  169. tx := l.svcCtx.DbEngin.Begin()
  170. defer func() {
  171. if r := recover(); r != nil {
  172. tx.Rollback()
  173. err = fmt.Errorf("transaction panic: %v", r)
  174. } else if err != nil {
  175. tx.Rollback()
  176. }
  177. }()
  178. userID, _ := strconv.ParseInt(req.Parameters[constants.UserId], 10, 64)
  179. taskID := utils.GenSnowflakeID()
  180. taskModel := models.Task{
  181. Id: taskID,
  182. Name: req.Name,
  183. Description: req.Description,
  184. CommitTime: time.Now(),
  185. Status: constants.StatusSaved,
  186. AdapterTypeDict: constants.AdapterTypeHPC,
  187. UserId: userID,
  188. UserName: req.Parameters[constants.UserName],
  189. }
  190. if err = tx.Table("task").Create(&taskModel).Error; err != nil {
  191. return "", fmt.Errorf("failed to create task: %w", err)
  192. }
  193. clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
  194. if err != nil {
  195. return "", err
  196. }
  197. paramsJSON, err := jsoniter.MarshalToString(req)
  198. if err != nil {
  199. return "", fmt.Errorf("failed to marshal parameters: %w", err)
  200. }
  201. //解析slurm脚本内容
  202. var resource models.ResourceSpec
  203. if req.Backend == string(constants.HPC_SYSTEM_SLURM) {
  204. parser := utils.NewSlurmParser()
  205. slurmResource := parser.ParseScript(jobScript)
  206. resource = models.ResourceSpec{
  207. //资源规格名称,采用拼接的方式 集群名+队列名
  208. ResourceName: fmt.Sprintf("%s_%s", clusterInfo.Name, slurmResource.Partition),
  209. Partition: slurmResource.Partition,
  210. Specifications: slurmResource,
  211. }
  212. }
  213. clusterID := utils.StringToInt64(clusterInfo.Id)
  214. hpcTask := models.TaskHpc{
  215. Id: utils.GenSnowflakeID(),
  216. TaskId: taskID,
  217. AdapterId: clusterInfo.AdapterId,
  218. AdapterName: adapterInfo.Name,
  219. ClusterId: clusterID,
  220. ClusterName: clusterInfo.Name,
  221. Name: taskModel.Name,
  222. Backend: req.Backend,
  223. OperateType: req.OperateType,
  224. CmdScript: req.Parameters["cmdScript"],
  225. WallTime: req.Parameters["wallTime"],
  226. AppType: req.Parameters["appType"],
  227. AppName: req.App,
  228. Queue: req.Parameters["queue"],
  229. SubmitType: req.Parameters["submitType"],
  230. NNode: req.Parameters["nNode"],
  231. Account: clusterInfo.Username,
  232. StdInput: req.Parameters["stdInput"],
  233. Partition: req.Parameters["partition"],
  234. CreatedTime: time.Now(),
  235. UpdatedTime: time.Now(),
  236. Status: constants.StatusDeploying,
  237. UserId: userID,
  238. Params: paramsJSON,
  239. Script: jobScript,
  240. JobId: jobId,
  241. WorkDir: workDir,
  242. ResourceSpec: resource,
  243. }
  244. if err = tx.Table("task_hpc").Create(&hpcTask).Error; err != nil {
  245. return "", fmt.Errorf("failed to create HPC task: %w", err)
  246. }
  247. noticeInfo := clientCore.NoticeInfo{
  248. AdapterId: clusterInfo.AdapterId,
  249. AdapterName: adapterInfo.Name,
  250. ClusterId: clusterID,
  251. ClusterName: clusterInfo.Name,
  252. NoticeType: "create",
  253. TaskName: req.Name,
  254. TaskId: taskID,
  255. Incident: "任务创建中",
  256. CreatedTime: time.Now(),
  257. }
  258. if err = tx.Table("t_notice").Create(&noticeInfo).Error; err != nil {
  259. return "", fmt.Errorf("failed to create notice: %w", err)
  260. }
  261. if err = tx.Commit().Error; err != nil {
  262. return "", fmt.Errorf("transaction commit failed: %w", err)
  263. }
  264. return utils.Int64ToString(taskID), nil
  265. }
  266. func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
  267. reqJSON, err := jsoniter.MarshalToString(req)
  268. if err != nil {
  269. return nil, fmt.Errorf("failed to marshal request: %w", err)
  270. }
  271. logc.Infof(l.ctx, "提交超算任务请求参数: %s", reqJSON)
  272. jobName := generateJobName(req)
  273. req.Parameters["jobName"] = jobName
  274. // 获取集群和适配器信息
  275. clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
  276. if err != nil {
  277. return nil, err
  278. }
  279. scriptContent := req.ScriptContent
  280. if scriptContent == "" {
  281. // 获取模板
  282. var templateInfo types.HpcAppTemplateInfo
  283. tx := l.svcCtx.DbEngin.Table("hpc_app_template").
  284. Where("cluster_id = ? and app = ? ", req.ClusterId, req.App)
  285. if req.OperateType != "" {
  286. tx.Where("app_type = ?", req.OperateType)
  287. }
  288. if err := tx.First(&templateInfo).Error; err != nil {
  289. return nil, fmt.Errorf("获取HPC应用【%s】模板失败: %w", req.App, err)
  290. }
  291. // 转换请求参数
  292. jobRequest, err := ConvertToJobRequest(req)
  293. if err != nil {
  294. return nil, err
  295. }
  296. // 渲染脚本
  297. script, err := l.RenderJobScript(templateInfo.Content, &jobRequest)
  298. if err != nil {
  299. return nil, err
  300. }
  301. scriptContent = script
  302. }
  303. q, _ := jsoniter.MarshalToString(scriptContent)
  304. submitQ := types.SubmitHpcTaskReq{
  305. App: req.App,
  306. ClusterId: req.ClusterId,
  307. JobName: jobName,
  308. ScriptContent: scriptContent,
  309. Parameters: req.Parameters,
  310. Backend: req.Backend,
  311. }
  312. log.Info().Msgf("Submitting HPC task to cluster %s with params: %s", clusterInfo.Name, q)
  313. resp, err = l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id].SubmitTask(l.ctx, submitQ)
  314. if err != nil {
  315. log.Error().Err(err).Msgf("提交超算任务失败, cluster: %s, jobName: %s, scriptContent: %s", clusterInfo.Name, jobName, scriptContent)
  316. return nil, fmt.Errorf("网络请求失败,请稍后重试")
  317. }
  318. jobID := resp.Data.JobInfo["jobId"]
  319. workDir := resp.Data.JobInfo["jobDir"]
  320. taskID, err := l.SaveHpcTaskToDB(req, scriptContent, jobID, workDir)
  321. if err != nil {
  322. log.Error().Msgf("超算任务保存到数据库失败, cluster: %s, jobName: %s, scriptContent: %s, error: %v", clusterInfo.Name, jobName, scriptContent, err)
  323. return nil, fmt.Errorf("保存超算任务到数据库失败: %w", err)
  324. }
  325. resp.Data.JobInfo["taskId"] = taskID
  326. return resp, nil
  327. }
  328. func generateJobName(req *types.CommitHpcTaskReq) string {
  329. if req.OperateType == "" {
  330. return req.Name
  331. }
  332. return req.Name + "_" + req.OperateType
  333. }

PCM is positioned as a software stack over clouds, aiming to build the standards and ecosystem for heterogeneous cloud collaboration in JCC in a non-intrusive, autonomous, peer-to-peer manner.