You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package repo
  2. import (
  3. "errors"
  4. "fmt"
  5. "path"
  6. "strings"
  7. "code.gitea.io/gitea/models"
  8. "code.gitea.io/gitea/modules/context"
  9. "code.gitea.io/gitea/modules/log"
  10. "code.gitea.io/gitea/modules/setting"
  11. "code.gitea.io/gitea/modules/storage"
  12. uuid "github.com/satori/go.uuid"
  13. )
  // Model_prefix is the object-key prefix, inside the configured bucket,
// under which saved AI model files are placed.
const Model_prefix = "aimodels/"
  17. func SaveModelByParameters(jobId string, name string, version string, label string, description string, userId int64) error {
  18. aiTask, err := models.GetCloudbrainByJobID(jobId)
  19. if err != nil {
  20. log.Info("query task error." + err.Error())
  21. return err
  22. }
  23. uuid := uuid.NewV4()
  24. id := uuid.String()
  25. modelPath := id
  26. parent := id
  27. var modelSize int64
  28. cloudType := models.TypeCloudBrainTwo
  29. log.Info("find task name:" + aiTask.JobName)
  30. aimodels := models.QueryModelByName(name, userId)
  31. if len(aimodels) > 0 {
  32. for _, model := range aimodels {
  33. if model.ID == model.Parent {
  34. parent = model.ID
  35. }
  36. }
  37. }
  38. cloudType = aiTask.Type
  39. //download model zip //train type
  40. if cloudType == models.TypeCloudBrainTwo {
  41. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  42. if err != nil {
  43. log.Info("download model from CloudBrainTwo faild." + err.Error())
  44. return err
  45. }
  46. }
  47. model := &models.AiModelManage{
  48. ID: id,
  49. Version: version,
  50. Label: label,
  51. Name: name,
  52. Description: description,
  53. Parent: parent,
  54. Type: cloudType,
  55. Path: modelPath,
  56. Size: modelSize,
  57. AttachmentId: aiTask.Uuid,
  58. RepoId: aiTask.RepoID,
  59. UserId: userId,
  60. }
  61. models.SaveModelToDb(model)
  62. log.Info("save model end.")
  63. return nil
  64. }
  65. func SaveModel(ctx *context.Context) {
  66. log.Info("save model start.")
  67. JobId := ctx.Query("JobId")
  68. name := ctx.Query("Name")
  69. version := ctx.Query("Version")
  70. label := ctx.Query("Label")
  71. description := ctx.Query("Description")
  72. err := SaveModelByParameters(JobId, name, version, label, description, ctx.User.ID)
  73. if err != nil {
  74. log.Info("save model error." + err.Error())
  75. ctx.Error(500, fmt.Sprintf("save model error. %v", err))
  76. return
  77. }
  78. log.Info("save model end.")
  79. }
  80. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  81. dataActualPath := setting.Bucket + "/" + Model_prefix +
  82. models.AttachmentRelativePath(modelUUID) +
  83. "/"
  84. modelDbResult, err := storage.GetObsListObject(jobName, parentDir)
  85. if err != nil {
  86. log.Info("get TrainJobListModel failed:", err)
  87. return "", 0, err
  88. }
  89. if len(modelDbResult) == 0 {
  90. return "", 0, errors.New("cannot create model, as model is empty.")
  91. }
  92. var size int64
  93. prefix := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/") + "/"
  94. for _, modelFile := range modelDbResult {
  95. destKeyNamePrefix := Model_prefix + models.AttachmentRelativePath(modelUUID) + "/"
  96. log.Info("copy file, bucket=" + setting.Bucket + ", src keyname=" + prefix + modelFile.FileName)
  97. log.Info("Dest key name=" + destKeyNamePrefix + modelFile.FileName)
  98. err := storage.ObsCopyFile(setting.Bucket, prefix+modelFile.FileName, setting.Bucket, destKeyNamePrefix+modelFile.FileName)
  99. if err != nil {
  100. log.Info("copy failed.")
  101. }
  102. size += modelFile.Size
  103. }
  104. return dataActualPath, size, nil
  105. }
  106. func DeleteModel(ctx *context.Context) {
  107. log.Info("delete model start.")
  108. id := ctx.Query("ID")
  109. err := DeleteModelByID(id)
  110. if err != nil {
  111. ctx.JSON(500, err.Error())
  112. } else {
  113. ctx.JSON(200, map[string]string{
  114. "result_code": "0",
  115. })
  116. }
  117. }
  118. func DeleteModelByID(id string) error {
  119. log.Info("delete model start. id=" + id)
  120. model, err := models.QueryModelById(id)
  121. if err == nil {
  122. log.Info("bucket=" + setting.Bucket + " path=" + model.Path)
  123. if strings.HasPrefix(model.Path, setting.Bucket+"/"+Model_prefix) {
  124. err := storage.ObsRemoveObject(setting.Bucket, model.Path[len(setting.Bucket)+1:])
  125. if err != nil {
  126. log.Info("Failed to delete model. id=" + id)
  127. return err
  128. }
  129. }
  130. return models.DeleteModelById(id)
  131. }
  132. return err
  133. }
  134. func DownloadModel(ctx *context.Context) {
  135. log.Info("download model start.")
  136. }
  137. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  138. return models.QueryModel(&models.AiModelQueryOptions{
  139. ListOptions: models.ListOptions{
  140. Page: page,
  141. PageSize: setting.UI.IssuePagingNum,
  142. },
  143. RepoID: repoId,
  144. Type: -1,
  145. })
  146. }
  147. func ShowModelInfo(ctx *context.Context) {
  148. log.Info("ShowModelInfo start.")
  149. page := ctx.QueryInt("page")
  150. if page <= 0 {
  151. page = 1
  152. }
  153. repoId := ctx.QueryInt64("repoId")
  154. Type := -1
  155. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  156. ListOptions: models.ListOptions{
  157. Page: page,
  158. PageSize: setting.UI.IssuePagingNum,
  159. },
  160. RepoID: repoId,
  161. Type: Type,
  162. })
  163. if err != nil {
  164. ctx.ServerError("Cloudbrain", err)
  165. return
  166. }
  167. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  168. pager.SetDefaultParams(ctx)
  169. ctx.Data["Page"] = pager
  170. ctx.Data["PageIsCloudBrain"] = true
  171. ctx.Data["Tasks"] = modelResult
  172. ctx.HTML(200, "")
  173. }
  174. func ModifyModel(id string, description string) error {
  175. err := models.ModifyModelDescription(id, description)
  176. if err == nil {
  177. log.Info("modify success.")
  178. } else {
  179. log.Info("Failed to modify.id=" + id + " desc=" + description)
  180. }
  181. return err
  182. }
  183. func ModifyModelInfo(ctx *context.Context) {
  184. log.Info("delete model start.")
  185. id := ctx.Query("ID")
  186. description := ctx.Query("Description")
  187. err := ModifyModel(id, description)
  188. if err == nil {
  189. ctx.HTML(200, "success")
  190. } else {
  191. ctx.HTML(500, "Failed.")
  192. }
  193. }