You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.go 23 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. package cloudbrainTask
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net/http"
  9. "os"
  10. "path"
  11. "regexp"
  12. "strings"
  13. "code.gitea.io/gitea/modules/obs"
  14. "code.gitea.io/gitea/modules/git"
  15. "code.gitea.io/gitea/modules/storage"
  16. "github.com/unknwon/com"
  17. "code.gitea.io/gitea/models"
  18. "code.gitea.io/gitea/modules/cloudbrain"
  19. "code.gitea.io/gitea/modules/context"
  20. "code.gitea.io/gitea/modules/grampus"
  21. "code.gitea.io/gitea/modules/log"
  22. "code.gitea.io/gitea/modules/modelarts"
  23. "code.gitea.io/gitea/modules/redis/redis_key"
  24. "code.gitea.io/gitea/modules/redis/redis_lock"
  25. "code.gitea.io/gitea/modules/setting"
  26. api "code.gitea.io/gitea/modules/structs"
  27. "code.gitea.io/gitea/modules/util"
  28. "code.gitea.io/gitea/services/cloudbrain/resource"
  29. "code.gitea.io/gitea/services/reward/point/account"
  30. )
  // jobNamePattern validates user-supplied display job names: 3 to 36 characters
// drawn from lowercase letters, digits, '-' and '_', where the first character
// is a lowercase letter or digit and the last is a lowercase letter, digit or '-'.
var jobNamePattern = regexp.MustCompile("^[a-z0-9][a-z0-9-_]{1,34}[a-z0-9-]$")
  32. func GrampusTrainJobGpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  33. displayJobName := option.DisplayJobName
  34. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  35. uuid := option.Attachment
  36. description := option.Description
  37. bootFile := strings.TrimSpace(option.BootFile)
  38. params := option.Params
  39. repo := ctx.Repo.Repository
  40. codeLocalPath := setting.JobPath + jobName + cloudbrain.CodeMountPath + "/"
  41. codeMinioPath := setting.CBCodePathPrefix + jobName + cloudbrain.CodeMountPath + "/"
  42. branchName := option.BranchName
  43. image := strings.TrimSpace(option.Image)
  44. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  45. defer lock.UnLock()
  46. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  47. if err != nil {
  48. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  49. return
  50. }
  51. //prepare code and out path
  52. _, err = ioutil.ReadDir(codeLocalPath)
  53. if err == nil {
  54. os.RemoveAll(codeLocalPath)
  55. }
  56. if err := downloadZipCode(ctx, codeLocalPath, branchName); err != nil {
  57. log.Error("downloadZipCode failed, server timed out: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  58. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  59. }
  60. //todo: upload code (send to file_server todo this work?)
  61. //upload code
  62. if err := uploadCodeToMinio(codeLocalPath+"/", jobName, cloudbrain.CodeMountPath+"/"); err != nil {
  63. log.Error("Failed to uploadCodeToMinio: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  64. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  65. return
  66. }
  67. modelPath := setting.JobPath + jobName + cloudbrain.ModelMountPath + "/"
  68. if err := mkModelPath(modelPath); err != nil {
  69. log.Error("Failed to mkModelPath: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  70. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  71. return
  72. }
  73. //init model readme
  74. if err := uploadCodeToMinio(modelPath, jobName, cloudbrain.ModelMountPath+"/"); err != nil {
  75. log.Error("Failed to uploadCodeToMinio: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  76. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  77. return
  78. }
  79. var datasetRemotePath, allFileName string
  80. for _, datasetInfo := range datasetInfos {
  81. if datasetRemotePath == "" {
  82. datasetRemotePath = datasetInfo.DataLocalPath
  83. allFileName = datasetInfo.FullName
  84. } else {
  85. datasetRemotePath = datasetRemotePath + ";" + datasetInfo.DataLocalPath
  86. allFileName = allFileName + ";" + datasetInfo.FullName
  87. }
  88. }
  89. //prepare command
  90. preTrainModelPath := getPreTrainModelPath(option.PreTrainModelUrl, option.CkptName)
  91. command, err := generateCommand(repo.Name, grampus.ProcessorTypeGPU, codeMinioPath+cloudbrain.DefaultBranchName+".zip", datasetRemotePath, bootFile, params, setting.CBCodePathPrefix+jobName+cloudbrain.ModelMountPath+"/", allFileName, preTrainModelPath, option.CkptName)
  92. if err != nil {
  93. log.Error("Failed to generateCommand: %s (%v)", displayJobName, err, ctx.Data["MsgID"])
  94. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Create task failed, internal error"))
  95. return
  96. }
  97. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  98. req := &grampus.GenerateTrainJobReq{
  99. JobName: jobName,
  100. DisplayJobName: displayJobName,
  101. ComputeResource: models.GPUResource,
  102. ProcessType: grampus.ProcessorTypeGPU,
  103. Command: command,
  104. ImageUrl: image,
  105. Description: description,
  106. BootFile: bootFile,
  107. Uuid: uuid,
  108. CommitID: commitID,
  109. BranchName: branchName,
  110. Params: option.Params,
  111. EngineName: image,
  112. DatasetNames: datasetNames,
  113. DatasetInfos: datasetInfos,
  114. IsLatestVersion: modelarts.IsLatestVersion,
  115. VersionCount: modelarts.VersionCountOne,
  116. WorkServerNumber: 1,
  117. Spec: spec,
  118. }
  119. if option.ModelName != "" { //使用预训练模型训练
  120. req.ModelName = option.ModelName
  121. req.LabelName = option.LabelName
  122. req.CkptName = option.CkptName
  123. req.ModelVersion = option.ModelVersion
  124. req.PreTrainModelUrl = option.PreTrainModelUrl
  125. }
  126. jobId, err := grampus.GenerateTrainJob(ctx, req)
  127. if err != nil {
  128. log.Error("GenerateTrainJob failed:%v", err.Error(), ctx.Data["MsgID"])
  129. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  130. return
  131. }
  132. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  133. }
  134. func checkParameters(ctx *context.Context, option api.CreateTrainJobOption, lock *redis_lock.DistributeLock, repo *models.Repository) (*models.Specification, map[string]models.DatasetInfo, string, error) {
  135. isOk, err := lock.Lock(models.CloudbrainKeyDuration)
  136. if !isOk {
  137. log.Error("lock processed failed:%v", err, ctx.Data["MsgID"])
  138. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_samejob_err"))
  139. }
  140. if !jobNamePattern.MatchString(option.DisplayJobName) {
  141. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_jobname_err"))
  142. }
  143. bootFileExist, err := ctx.Repo.FileExists(option.BootFile, option.BranchName)
  144. if err != nil || !bootFileExist {
  145. log.Error("Get bootfile error:", err, ctx.Data["MsgID"])
  146. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_bootfile_err"))
  147. }
  148. computeResource := models.GPUResource
  149. if option.Type == 3 {
  150. computeResource = models.NPUResource
  151. }
  152. //check count limit
  153. count, err := GetNotFinalStatusTaskCount(ctx.User.ID, models.TypeC2Net, string(models.JobTypeTrain), computeResource)
  154. if err != nil {
  155. log.Error("GetGrampusCountByUserID failed:%v", err, ctx.Data["MsgID"])
  156. return nil, nil, "", fmt.Errorf("system error")
  157. } else {
  158. if count >= 1 {
  159. log.Error("the user already has running or waiting task", ctx.Data["MsgID"])
  160. return nil, nil, "", fmt.Errorf("you have already a running or waiting task, can not create more.")
  161. }
  162. }
  163. //check param
  164. if err := grampusParamCheckCreateTrainJob(option.BootFile, option.BranchName); err != nil {
  165. log.Error("paramCheckCreateTrainJob failed:(%v)", err, ctx.Data["MsgID"])
  166. return nil, nil, "", err
  167. }
  168. //check whether the task name in the project is duplicated
  169. tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, string(models.JobTypeTrain), option.DisplayJobName)
  170. if err == nil {
  171. if len(tasks) != 0 {
  172. log.Error("the job name did already exist", ctx.Data["MsgID"])
  173. return nil, nil, "", fmt.Errorf("The job name did already exist.")
  174. }
  175. } else {
  176. if !models.IsErrJobNotExist(err) {
  177. log.Error("system error, %v", err, ctx.Data["MsgID"])
  178. return nil, nil, "", fmt.Errorf("system error")
  179. }
  180. }
  181. //check specification
  182. computeType := models.GPU
  183. if option.Type == 3 {
  184. computeType = models.NPU
  185. }
  186. spec, err := resource.GetAndCheckSpec(ctx.User.ID, option.SpecId, models.FindSpecsOptions{
  187. JobType: models.JobTypeTrain,
  188. ComputeResource: computeType,
  189. Cluster: models.C2NetCluster,
  190. })
  191. if err != nil || spec == nil {
  192. return nil, nil, "", fmt.Errorf("Resource specification is not available.")
  193. }
  194. if !account.IsPointBalanceEnough(ctx.User.ID, spec.UnitPrice) {
  195. log.Error("point balance is not enough,userId=%d specId=%d", ctx.User.ID, spec.ID)
  196. return nil, nil, "", fmt.Errorf(ctx.Tr("points.insufficient_points_balance"))
  197. }
  198. //check dataset
  199. datasetInfos, datasetNames, err := models.GetDatasetInfo(option.Attachment, computeType)
  200. if err != nil {
  201. log.Error("GetDatasetInfo failed: %v", err, ctx.Data["MsgID"])
  202. return nil, nil, "", fmt.Errorf(ctx.Tr("cloudbrain.error.dataset_select"))
  203. }
  204. return spec, datasetInfos, datasetNames, err
  205. }
  206. func GrampusTrainJobNpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  207. displayJobName := option.DisplayJobName
  208. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  209. uuid := option.Attachment
  210. description := option.Description
  211. bootFile := strings.TrimSpace(option.BootFile)
  212. params := option.Params
  213. repo := ctx.Repo.Repository
  214. codeLocalPath := setting.JobPath + jobName + modelarts.CodePath
  215. codeObsPath := grampus.JobPath + jobName + modelarts.CodePath
  216. branchName := option.BranchName
  217. isLatestVersion := modelarts.IsLatestVersion
  218. versionCount := modelarts.VersionCountOne
  219. engineName := option.Image
  220. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  221. defer lock.UnLock()
  222. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  223. if err != nil {
  224. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  225. return
  226. }
  227. //prepare code and out path
  228. _, err = ioutil.ReadDir(codeLocalPath)
  229. if err == nil {
  230. os.RemoveAll(codeLocalPath)
  231. }
  232. if err := downloadZipCode(ctx, codeLocalPath, branchName); err != nil {
  233. log.Error("downloadZipCode failed, server timed out: %s (%v)", repo.FullName(), err)
  234. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  235. return
  236. }
  237. //todo: upload code (send to file_server todo this work?)
  238. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.OutputPath); err != nil {
  239. log.Error("Failed to obsMkdir_output: %s (%v)", repo.FullName(), err)
  240. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  241. return
  242. }
  243. if err := uploadCodeToObs(codeLocalPath, jobName, ""); err != nil {
  244. log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err)
  245. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  246. return
  247. }
  248. var datasetRemotePath, allFileName string
  249. for _, datasetInfo := range datasetInfos {
  250. if datasetRemotePath == "" {
  251. datasetRemotePath = datasetInfo.DataLocalPath + "'" + datasetInfo.FullName + "'"
  252. allFileName = datasetInfo.FullName
  253. } else {
  254. datasetRemotePath = datasetRemotePath + ";" + datasetInfo.DataLocalPath + "'" + datasetInfo.FullName + "'"
  255. allFileName = allFileName + ";" + datasetInfo.FullName
  256. }
  257. }
  258. //prepare command
  259. preTrainModelPath := getPreTrainModelPath(option.PreTrainModelUrl, option.CkptName)
  260. command, err := generateCommand(repo.Name, grampus.ProcessorTypeNPU, codeObsPath+cloudbrain.DefaultBranchName+".zip", datasetRemotePath, bootFile, params, setting.CodePathPrefix+jobName+modelarts.OutputPath, allFileName, preTrainModelPath, option.CkptName)
  261. if err != nil {
  262. log.Error("Failed to generateCommand: %s (%v)", displayJobName, err, ctx.Data["MsgID"])
  263. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Create task failed, internal error"))
  264. return
  265. }
  266. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  267. req := &grampus.GenerateTrainJobReq{
  268. JobName: jobName,
  269. DisplayJobName: displayJobName,
  270. ComputeResource: models.NPUResource,
  271. ProcessType: grampus.ProcessorTypeNPU,
  272. Command: command,
  273. ImageId: option.ImageID,
  274. Description: description,
  275. CodeObsPath: codeObsPath,
  276. BootFileUrl: codeObsPath + bootFile,
  277. BootFile: bootFile,
  278. WorkServerNumber: option.WorkServerNumber,
  279. Uuid: uuid,
  280. CommitID: commitID,
  281. IsLatestVersion: isLatestVersion,
  282. BranchName: branchName,
  283. Params: option.Params,
  284. EngineName: engineName,
  285. VersionCount: versionCount,
  286. TotalVersionCount: modelarts.TotalVersionCount,
  287. DatasetNames: datasetNames,
  288. DatasetInfos: datasetInfos,
  289. Spec: spec,
  290. CodeName: strings.ToLower(repo.Name),
  291. }
  292. if option.ModelName != "" { //使用预训练模型训练
  293. req.ModelName = option.ModelName
  294. req.LabelName = option.LabelName
  295. req.CkptName = option.CkptName
  296. req.ModelVersion = option.ModelVersion
  297. req.PreTrainModelUrl = option.PreTrainModelUrl
  298. req.PreTrainModelPath = preTrainModelPath
  299. }
  300. jobId, err := grampus.GenerateTrainJob(ctx, req)
  301. if err != nil {
  302. log.Error("GenerateTrainJob failed:%v", err.Error())
  303. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  304. return
  305. }
  306. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  307. }
  308. func obsMkdir(dir string) error {
  309. input := &obs.PutObjectInput{}
  310. input.Bucket = setting.Bucket
  311. input.Key = dir
  312. _, err := storage.ObsCli.PutObject(input)
  313. if err != nil {
  314. log.Error("PutObject(%s) failed: %s", input.Key, err.Error())
  315. return err
  316. }
  317. return nil
  318. }
  319. func uploadCodeToObs(codePath, jobName, parentDir string) error {
  320. files, err := readDir(codePath)
  321. if err != nil {
  322. log.Error("readDir(%s) failed: %s", codePath, err.Error())
  323. return err
  324. }
  325. for _, file := range files {
  326. if file.IsDir() {
  327. input := &obs.PutObjectInput{}
  328. input.Bucket = setting.Bucket
  329. input.Key = parentDir + file.Name() + "/"
  330. _, err = storage.ObsCli.PutObject(input)
  331. if err != nil {
  332. log.Error("PutObject(%s) failed: %s", input.Key, err.Error())
  333. return err
  334. }
  335. if err = uploadCodeToObs(codePath+file.Name()+"/", jobName, parentDir+file.Name()+"/"); err != nil {
  336. log.Error("uploadCodeToObs(%s) failed: %s", file.Name(), err.Error())
  337. return err
  338. }
  339. } else {
  340. input := &obs.PutFileInput{}
  341. input.Bucket = setting.Bucket
  342. input.Key = setting.CodePathPrefix + jobName + "/code/" + parentDir + file.Name()
  343. input.SourceFile = codePath + file.Name()
  344. _, err = storage.ObsCli.PutFile(input)
  345. if err != nil {
  346. log.Error("PutFile(%s) failed: %s", input.SourceFile, err.Error())
  347. return err
  348. }
  349. }
  350. }
  351. return nil
  352. }
  353. func grampusParamCheckCreateTrainJob(bootFile string, branchName string) error {
  354. if !strings.HasSuffix(strings.TrimSpace(bootFile), ".py") {
  355. log.Error("the boot file(%s) must be a python file", bootFile)
  356. return errors.New("启动文件必须是python文件")
  357. }
  358. if branchName == "" {
  359. log.Error("the branch must not be null!", branchName)
  360. return errors.New("代码分支不能为空!")
  361. }
  362. return nil
  363. }
  364. func downloadZipCode(ctx *context.Context, codePath, branchName string) error {
  365. archiveType := git.ZIP
  366. archivePath := codePath
  367. if !com.IsDir(archivePath) {
  368. if err := os.MkdirAll(archivePath, os.ModePerm); err != nil {
  369. log.Error("MkdirAll failed:" + err.Error())
  370. return err
  371. }
  372. }
  373. // Get corresponding commit.
  374. var (
  375. commit *git.Commit
  376. err error
  377. )
  378. gitRepo := ctx.Repo.GitRepo
  379. if err != nil {
  380. log.Error("OpenRepository failed:" + err.Error())
  381. return err
  382. }
  383. if gitRepo.IsBranchExist(branchName) {
  384. commit, err = gitRepo.GetBranchCommit(branchName)
  385. if err != nil {
  386. log.Error("GetBranchCommit failed:" + err.Error())
  387. return err
  388. }
  389. } else {
  390. log.Error("the branch is not exist: " + branchName)
  391. return fmt.Errorf("The branch does not exist.")
  392. }
  393. archivePath = path.Join(archivePath, grampus.CodeArchiveName)
  394. if !com.IsFile(archivePath) {
  395. if err := commit.CreateArchive(archivePath, git.CreateArchiveOpts{
  396. Format: archiveType,
  397. Prefix: setting.Repository.PrefixArchiveFiles,
  398. }); err != nil {
  399. log.Error("CreateArchive failed:" + err.Error())
  400. return err
  401. }
  402. }
  403. return nil
  404. }
  405. func uploadCodeToMinio(codePath, jobName, parentDir string) error {
  406. files, err := readDir(codePath)
  407. if err != nil {
  408. log.Error("readDir(%s) failed: %s", codePath, err.Error())
  409. return err
  410. }
  411. for _, file := range files {
  412. if file.IsDir() {
  413. if err = uploadCodeToMinio(codePath+file.Name()+"/", jobName, parentDir+file.Name()+"/"); err != nil {
  414. log.Error("uploadCodeToMinio(%s) failed: %s", file.Name(), err.Error())
  415. return err
  416. }
  417. } else {
  418. destObject := setting.CBCodePathPrefix + jobName + parentDir + file.Name()
  419. sourceFile := codePath + file.Name()
  420. err = storage.Attachments.UploadObject(destObject, sourceFile)
  421. if err != nil {
  422. log.Error("UploadObject(%s) failed: %s", file.Name(), err.Error())
  423. return err
  424. }
  425. }
  426. }
  427. return nil
  428. }
  // readDir returns the entries of directory dirname as os.FileInfo values, in
// the order the file system reports them (not sorted). Open and read errors
// are returned as-is; an io.EOF from the read yields (nil, nil).
func readDir(dirname string) ([]os.FileInfo, error) {
	dir, openErr := os.Open(dirname)
	if openErr != nil {
		return nil, openErr
	}
	entries, readErr := dir.Readdir(0)
	dir.Close()

	switch {
	case readErr == nil:
		return entries, nil
	case readErr == io.EOF:
		// TODO: empty folders cannot be uploaded.
		return nil, nil
	default:
		return nil, readErr
	}
}
  446. func mkModelPath(modelPath string) error {
  447. return mkPathAndReadMeFile(modelPath, "You can put the files into this directory and download the files by the web page.")
  448. }
  449. func mkPathAndReadMeFile(path string, text string) error {
  450. err := os.MkdirAll(path, os.ModePerm)
  451. if err != nil {
  452. log.Error("MkdirAll(%s) failed:%v", path, err)
  453. return err
  454. }
  455. fileName := path + "README"
  456. f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
  457. if err != nil {
  458. log.Error("OpenFile failed", err.Error())
  459. return err
  460. }
  461. defer f.Close()
  462. _, err = f.WriteString(text)
  463. if err != nil {
  464. log.Error("WriteString failed", err.Error())
  465. return err
  466. }
  467. return nil
  468. }
  // getPreTrainModelPath strips the leading "<bucket>/" segment from
// pretrainModelDir and appends fileName. It returns "" when pretrainModelDir
// has no "/" or starts with "/" (i.e. has no bucket segment).
func getPreTrainModelPath(pretrainModelDir string, fileName string) string {
	parts := strings.SplitN(pretrainModelDir, "/", 2)
	if len(parts) < 2 || parts[0] == "" {
		return ""
	}
	return parts[1] + fileName
}
  478. func generateCommand(repoName, processorType, codeRemotePath, dataRemotePath, bootFile, paramSrc, outputRemotePath, datasetName, pretrainModelPath, pretrainModelFileName string) (string, error) {
  479. var command string
  480. workDir := grampus.NpuWorkDir
  481. if processorType == grampus.ProcessorTypeGPU {
  482. workDir = grampus.GpuWorkDir
  483. }
  484. command += "pwd;cd " + workDir + fmt.Sprintf(grampus.CommandPrepareScript, setting.Grampus.SyncScriptProject, setting.Grampus.SyncScriptProject)
  485. //download code & dataset
  486. if processorType == grampus.ProcessorTypeNPU {
  487. //no need to download code & dataset by internet
  488. } else if processorType == grampus.ProcessorTypeGPU {
  489. commandDownload := "./downloader_for_minio " + setting.Grampus.Env + " " + codeRemotePath + " " + grampus.CodeArchiveName + " '" + dataRemotePath + "' '" + datasetName + "'"
  490. commandDownload = processPretrainModelParameter(pretrainModelPath, pretrainModelFileName, commandDownload)
  491. command += commandDownload
  492. }
  493. //unzip code & dataset
  494. if processorType == grampus.ProcessorTypeNPU {
  495. //no need to process
  496. } else if processorType == grampus.ProcessorTypeGPU {
  497. unZipDatasetCommand := generateDatasetUnzipCommand(datasetName)
  498. commandUnzip := "cd " + workDir + "code;unzip -q master.zip;echo \"start to unzip dataset\";cd " + workDir + "dataset;" + unZipDatasetCommand
  499. command += commandUnzip
  500. }
  501. command += "echo \"unzip finished;start to exec code;\";"
  502. // set export
  503. var commandExport string
  504. if processorType == grampus.ProcessorTypeNPU {
  505. commandExport = "export bucket=" + setting.Bucket + " && export remote_path=" + outputRemotePath + ";"
  506. } else if processorType == grampus.ProcessorTypeGPU {
  507. commandExport = "export env=" + setting.Grampus.Env + " && export remote_path=" + outputRemotePath + ";"
  508. }
  509. command += commandExport
  510. //exec code
  511. var parameters models.Parameters
  512. var paramCode string
  513. if len(paramSrc) != 0 {
  514. err := json.Unmarshal([]byte(paramSrc), &parameters)
  515. if err != nil {
  516. log.Error("Failed to Unmarshal params: %s (%v)", paramSrc, err)
  517. return command, err
  518. }
  519. for _, parameter := range parameters.Parameter {
  520. paramCode += " --" + parameter.Label + "=" + parameter.Value
  521. }
  522. }
  523. var commandCode string
  524. if processorType == grampus.ProcessorTypeNPU {
  525. commandCode = "/bin/bash /home/work/run_train_for_openi.sh /home/work/openi.py /tmp/log/train.log" + paramCode + ";"
  526. } else if processorType == grampus.ProcessorTypeGPU {
  527. if pretrainModelFileName != "" {
  528. paramCode += " --ckpt_url" + "=" + workDir + "pretrainmodel/" + pretrainModelFileName
  529. }
  530. commandCode = "cd " + workDir + "code/" + strings.ToLower(repoName) + ";python " + bootFile + paramCode + ";"
  531. }
  532. command += commandCode
  533. //get exec result
  534. commandGetRes := "result=$?;"
  535. command += commandGetRes
  536. //upload models
  537. if processorType == grampus.ProcessorTypeNPU {
  538. commandUpload := "cd " + workDir + setting.Grampus.SyncScriptProject + "/;./uploader_for_npu " + setting.Bucket + " " + outputRemotePath + " " + workDir + "output/;"
  539. command += commandUpload
  540. } else if processorType == grampus.ProcessorTypeGPU {
  541. commandUpload := "cd " + workDir + setting.Grampus.SyncScriptProject + "/;./uploader_for_gpu " + setting.Grampus.Env + " " + outputRemotePath + " " + workDir + "output/;"
  542. command += commandUpload
  543. }
  544. //check exec result
  545. commandCheckRes := "bash -c \"[[ $result -eq 0 ]] && exit 0 || exit -1\""
  546. command += commandCheckRes
  547. return command, nil
  548. }
  // processPretrainModelParameter terminates the downloader command
// commandDownload with ";". When pretrainModelPath is non-empty, it first
// appends the quoted pre-trained model path and file name as two extra
// arguments.
func processPretrainModelParameter(pretrainModelPath string, pretrainModelFileName string, commandDownload string) string {
	if pretrainModelPath == "" {
		return commandDownload + ";"
	}
	return commandDownload + " '" + pretrainModelPath + "' '" + pretrainModelFileName + "';"
}
  // generateDatasetUnzipCommand returns the shell commands that extract the
// dataset archive(s) named in datasetName (';'-separated).
//
// A single dataset is extracted in place. A ".tar.gz" archive has its top-level
// directory stripped; anything else is treated as a zip. With several datasets,
// each ".tar.gz" is extracted as-is and every other archive is unzipped into
// "./<name without .zip>".
func generateDatasetUnzipCommand(datasetName string) string {
	names := strings.Split(datasetName, ";")
	if len(names) == 1 { // single dataset
		if strings.HasSuffix(datasetName, ".tar.gz") {
			return "tar --strip-components=1 -zxvf '" + datasetName + "';"
		}
		return "unzip -q '" + datasetName + "';"
	}

	var sb strings.Builder // multiple datasets
	for _, name := range names {
		if strings.HasSuffix(name, ".tar.gz") {
			sb.WriteString("tar -zxvf '" + name + "';")
			continue
		}
		sb.WriteString("unzip -q '" + name + "' -d './" + strings.TrimSuffix(name, ".zip") + "';")
	}
	return sb.String()
}
  576. func getPoolId() string {
  577. var resourcePools modelarts.ResourcePool
  578. json.Unmarshal([]byte(setting.ResourcePools), &resourcePools)
  579. return resourcePools.Info[0].ID
  580. }