You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

grampus.go 4.4 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. package grampus
  2. import (
  3. "code.gitea.io/gitea/models"
  4. "code.gitea.io/gitea/modules/context"
  5. "code.gitea.io/gitea/modules/log"
  6. "code.gitea.io/gitea/modules/notification"
  7. "code.gitea.io/gitea/modules/timeutil"
  8. "strings"
  9. )
  10. const (
  11. storageTypeOBS = "obs"
  12. WorkPath = "/home/ma-user/work"
  13. CodePath = "/code/"
  14. DatasetPath = "/dataset"
  15. OutputPath = "/output/"
  16. ResultPath = "/result/"
  17. LogPath = "/log/"
  18. JobPath = "/job/"
  19. OrderDesc = "desc" //向下查询
  20. OrderAsc = "asc" //向上查询
  21. Lines = 500
  22. TrainUrl = "train_url"
  23. DataUrl = "data_url"
  24. ResultUrl = "result_url"
  25. CkptUrl = "ckpt_url"
  26. DeviceTarget = "device_target"
  27. Ascend = "Ascend"
  28. PerPage = 10
  29. IsLatestVersion = "1"
  30. NotLatestVersion = "0"
  31. VersionCount = 1
  32. SortByCreateTime = "create_time"
  33. ConfigTypeCustom = "custom"
  34. TotalVersionCount = 1
  35. ProcessorTypeNPU = "npu.huawei.com/NPU"
  36. ProcessorTypeGPU = "nvidia.com/gpu"
  37. CommandPrepareScript = "pwd;cd /tmp;wget https://git.openi.org.cn/lewis/script_for_grampus/archive/master.zip;unzip master.zip;cd script_for_grampus;"
  38. ScriptSyncObsCodeAndDataset = "sync_obs_code_and_dataset.py"
  39. )
  40. var (
  41. poolInfos *models.PoolInfos
  42. FlavorInfos *models.FlavorInfos
  43. ImageInfos *models.ImageInfosModelArts
  44. )
  45. type GenerateTrainJobReq struct {
  46. JobName string
  47. Command string
  48. ResourceSpecId string
  49. ImageUrl string //与image_id二选一,都有的情况下优先image_url
  50. ImageId string
  51. DisplayJobName string
  52. Uuid string
  53. Description string
  54. CodeObsPath string
  55. BootFile string
  56. BootFileUrl string
  57. DataUrl string
  58. TrainUrl string
  59. WorkServerNumber int
  60. EngineID int64
  61. CommitID string
  62. IsLatestVersion string
  63. BranchName string
  64. PreVersionId int64
  65. PreVersionName string
  66. FlavorName string
  67. VersionCount int
  68. EngineName string
  69. TotalVersionCount int
  70. ComputeResource string
  71. DatasetName string
  72. Params string
  73. }
  74. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  75. createTime := timeutil.TimeStampNow()
  76. jobResult, err := createJob(models.CreateGrampusJobRequest{
  77. Name: req.JobName,
  78. Tasks: []models.GrampusTasks{
  79. {
  80. Name: req.JobName,
  81. Command: req.Command,
  82. ResourceSpecId: req.ResourceSpecId,
  83. ImageId: req.ImageId,
  84. ImageUrl: req.ImageUrl,
  85. },
  86. },
  87. })
  88. if err != nil {
  89. log.Error("createJob failed: %v", err.Error())
  90. return err
  91. }
  92. jobID := jobResult.JobInfo.JobID
  93. err = models.CreateCloudbrain(&models.Cloudbrain{
  94. Status: TransTrainJobStatus(jobResult.JobInfo.Status),
  95. UserID: ctx.User.ID,
  96. RepoID: ctx.Repo.Repository.ID,
  97. JobID: jobID,
  98. JobName: req.JobName,
  99. DisplayJobName: req.DisplayJobName,
  100. JobType: string(models.JobTypeTrain),
  101. Type: models.TypeCloudBrainGrampus,
  102. Uuid: req.Uuid,
  103. DatasetName: req.DatasetName,
  104. CommitID: req.CommitID,
  105. IsLatestVersion: req.IsLatestVersion,
  106. ComputeResource: req.ComputeResource,
  107. ImageID: req.ImageId,
  108. TrainUrl: req.TrainUrl,
  109. BranchName: req.BranchName,
  110. Parameters: req.Params,
  111. BootFile: req.BootFile,
  112. DataUrl: req.DataUrl,
  113. FlavorCode: req.ResourceSpecId,
  114. Description: req.Description,
  115. WorkServerNumber: req.WorkServerNumber,
  116. FlavorName: req.FlavorName,
  117. EngineName: req.EngineName,
  118. VersionCount: req.VersionCount,
  119. TotalVersionCount: req.TotalVersionCount,
  120. CreatedUnix: createTime,
  121. UpdatedUnix: createTime,
  122. })
  123. if err != nil {
  124. log.Error("CreateCloudbrain(%s) failed:%v", req.DisplayJobName, err.Error())
  125. return err
  126. }
  127. var actionType models.ActionType
  128. if req.ComputeResource == models.NPUResource {
  129. actionType = models.ActionCreateTrainTask
  130. } else if req.ComputeResource == models.GPUResource {
  131. actionType = models.ActionCreateGPUTrainTask
  132. }
  133. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, req.DisplayJobName, actionType)
  134. return nil
  135. }
  136. func TransTrainJobStatus(status string) string {
  137. if status == "pending" {
  138. status = "waiting"
  139. }
  140. return strings.ToUpper(status)
  141. }