You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package repo
  2. import (
  3. "archive/zip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "os"
  9. "path/filepath"
  10. "code.gitea.io/gitea/models"
  11. "code.gitea.io/gitea/modules/context"
  12. "code.gitea.io/gitea/modules/log"
  13. "code.gitea.io/gitea/modules/setting"
  14. "code.gitea.io/gitea/modules/storage"
  15. uuid "github.com/satori/go.uuid"
  16. )
  17. func SaveModelByParameters(trainTaskId string, name string, version string, label string, description string, userId int64) {
  18. aiTask, err := models.GetCloudbrainByJobID(trainTaskId)
  19. if err != nil {
  20. log.Info("query task error." + err.Error())
  21. //ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  22. return
  23. }
  24. uuid := uuid.NewV4()
  25. id := uuid.String()
  26. modelPath := id
  27. parent := id
  28. var modelSize int64
  29. cloudType := models.TypeCloudBrainTwo
  30. log.Info("find task name:" + aiTask.JobName)
  31. aimodels := models.QueryModelByName(name, userId)
  32. if len(aimodels) > 0 {
  33. for _, model := range aimodels {
  34. if model.ID == model.Parent {
  35. parent = model.ID
  36. }
  37. }
  38. }
  39. cloudType = aiTask.Type
  40. //download model zip //train type
  41. if cloudType == models.TypeCloudBrainTwo {
  42. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  43. if err == nil {
  44. } else {
  45. log.Info("download model from CloudBrainTwo faild." + err.Error())
  46. //ctx.Error(500, fmt.Sprintf("%v", err))
  47. return
  48. }
  49. }
  50. model := &models.AiModelManage{
  51. ID: id,
  52. Version: version,
  53. Label: label,
  54. Name: name,
  55. Description: description,
  56. Parent: parent,
  57. Type: cloudType,
  58. Path: modelPath,
  59. Size: modelSize,
  60. AttachmentId: aiTask.Uuid,
  61. RepoId: aiTask.RepoID,
  62. UserId: userId,
  63. }
  64. models.SaveModelToDb(model)
  65. log.Info("save model end.")
  66. }
  67. func SaveModel(ctx *context.Context) {
  68. log.Info("save model start.")
  69. trainTaskId := ctx.QueryInt64("TrainTask")
  70. name := ctx.Query("Name")
  71. version := ctx.Query("Version")
  72. label := ctx.Query("Label")
  73. description := ctx.Query("Description")
  74. aiTasks, _, err := models.Cloudbrains(&models.CloudbrainsOptions{
  75. JobID: trainTaskId,
  76. })
  77. if err != nil {
  78. log.Info("query task error." + err.Error())
  79. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  80. return
  81. }
  82. uuid := uuid.NewV4()
  83. id := uuid.String()
  84. modelPath := id
  85. parent := id
  86. var modelSize int64
  87. cloudType := models.TypeCloudBrainTwo
  88. if len(aiTasks) != 1 {
  89. log.Info("query task error. len=" + fmt.Sprint(len(aiTasks)))
  90. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  91. return
  92. }
  93. aiTask := aiTasks[0]
  94. log.Info("find task name:" + aiTask.JobName)
  95. aimodels := models.QueryModelByName(name, ctx.User.ID)
  96. if len(aimodels) > 0 {
  97. for _, model := range aimodels {
  98. if model.ID == model.Parent {
  99. parent = model.ID
  100. }
  101. }
  102. }
  103. cloudType = aiTask.Cloudbrain.Type
  104. //download model zip //train type
  105. if cloudType == models.TypeCloudBrainTrainJob {
  106. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  107. if err == nil {
  108. } else {
  109. log.Info("download model from CloudBrainTwo faild." + err.Error())
  110. ctx.Error(500, fmt.Sprintf("%v", err))
  111. return
  112. }
  113. }
  114. model := &models.AiModelManage{
  115. ID: id,
  116. Version: version,
  117. Label: label,
  118. Name: name,
  119. Description: description,
  120. Parent: parent,
  121. Type: cloudType,
  122. Path: modelPath,
  123. Size: modelSize,
  124. AttachmentId: aiTask.Uuid,
  125. RepoId: aiTask.RepoID,
  126. UserId: ctx.User.ID,
  127. }
  128. models.SaveModelToDb(model)
  129. log.Info("save model end.")
  130. }
  131. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  132. dataActualPath := setting.Bucket + "/" +
  133. "aimodels/" +
  134. models.AttachmentRelativePath(modelUUID) +
  135. "/"
  136. models, err := storage.GetObsListObject(jobName, parentDir)
  137. if err != nil {
  138. log.Info("get TrainJobListModel failed:", err)
  139. return "", 0, err
  140. }
  141. if len(models) == 0 {
  142. return "", 0, errors.New("cannot create model, as model is empty.")
  143. }
  144. for _, modelFile := range models {
  145. log.Info("copy file, bucket=%s, src keyname=%s,dest keyname=%s", setting.Bucket, modelFile.ParenDir+modelFile.FileName, dataActualPath+modelFile.FileName)
  146. // err := storage.ObsCopyFile(setting.Bucket, modelFile.ParenDir+modelFile.FileName, setting.Bucket, dataActualPath+modelFile.FileName)
  147. // if err != nil {
  148. // log.Info("copy failed.")
  149. // }
  150. }
  151. return dataActualPath, 0, nil
  152. }
  153. func DeleteModel(ctx *context.Context) {
  154. log.Info("delete model start.")
  155. id := ctx.Query("ID")
  156. err := DeleteModelByID(id)
  157. if err != nil {
  158. ctx.JSON(500, err.Error())
  159. } else {
  160. ctx.JSON(200, map[string]string{
  161. "result_code": "0",
  162. })
  163. }
  164. }
  165. func DeleteModelByID(id string) error {
  166. log.Info("delete model start. id=" + id)
  167. return models.DeleteModelById(id)
  168. }
  169. func DownloadModel(ctx *context.Context) {
  170. log.Info("download model start.")
  171. }
  172. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  173. return models.QueryModel(&models.AiModelQueryOptions{
  174. ListOptions: models.ListOptions{
  175. Page: page,
  176. PageSize: setting.UI.IssuePagingNum,
  177. },
  178. RepoID: repoId,
  179. })
  180. }
  181. func ShowModelInfo(ctx *context.Context) {
  182. log.Info("ShowModelInfo start.")
  183. page := ctx.QueryInt("page")
  184. if page <= 0 {
  185. page = 1
  186. }
  187. repoId := ctx.QueryInt64("repoId")
  188. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  189. ListOptions: models.ListOptions{
  190. Page: page,
  191. PageSize: setting.UI.IssuePagingNum,
  192. },
  193. RepoID: repoId,
  194. })
  195. if err != nil {
  196. ctx.ServerError("Cloudbrain", err)
  197. return
  198. }
  199. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  200. pager.SetDefaultParams(ctx)
  201. ctx.Data["Page"] = pager
  202. ctx.Data["PageIsCloudBrain"] = true
  203. ctx.Data["Tasks"] = modelResult
  204. ctx.HTML(200, "")
  205. }
  206. func downloadModelFromCloudBrainOne(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  207. modelActualPath := setting.Attachment.Minio.RealPath +
  208. setting.Attachment.Minio.Bucket + "/" +
  209. "aimodels/" +
  210. models.AttachmentRelativePath(modelUUID) +
  211. "/"
  212. os.MkdirAll(modelActualPath, 0755)
  213. zipFile := modelActualPath + "model.zip"
  214. modelDir := setting.JobPath + jobName + "/model/"
  215. dir, _ := ioutil.ReadDir(modelDir)
  216. if len(dir) == 0 {
  217. return "", 0, errors.New("cannot create model, as model is empty.")
  218. }
  219. err := zipDir(modelDir, zipFile)
  220. if err != nil {
  221. return "", 0, err
  222. }
  223. fi, err := os.Stat(zipFile)
  224. if err == nil {
  225. return modelActualPath, fi.Size(), nil
  226. } else {
  227. return "", 0, err
  228. }
  229. }
  // zipDir writes every non-directory file found under dir (recursively) into
// a new zip archive at zipFile, truncating any existing file there.
//
// Entry names are the file paths relative to dir, converted to forward
// slashes as the zip format requires; directories get no entries of their
// own. dir may be given with or without a trailing path separator.
//
// A nil return means the archive was fully written and finalized. Errors from
// walking, reading, writing, and closing are returned wrapped with context.
func zipDir(dir, zipFile string) (err error) {
	fz, err := os.Create(zipFile)
	if err != nil {
		return fmt.Errorf("create zip file %q: %w", zipFile, err)
	}
	defer func() {
		// A failed close can leave the archive incomplete on disk; surface it.
		if cerr := fz.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("close zip file %q: %w", zipFile, cerr)
		}
	}()

	w := zip.NewWriter(fz)
	walkErr := filepath.Walk(dir, func(path string, info os.FileInfo, visitErr error) error {
		// Walk reports lstat/readdir failures here, and info may be nil then.
		if visitErr != nil {
			return visitErr
		}
		if info.IsDir() {
			return nil
		}
		return addFileToZip(w, dir, path)
	})
	if walkErr != nil {
		_ = w.Close() // best effort; the walk error is the one worth reporting
		return fmt.Errorf("zip directory %q: %w", dir, walkErr)
	}
	// Close writes the central directory; without it the archive is unreadable.
	if cerr := w.Close(); cerr != nil {
		return fmt.Errorf("finalize zip file %q: %w", zipFile, cerr)
	}
	return nil
}

// addFileToZip copies the file at path into w under its dir-relative, slash-separated name.
func addFileToZip(w *zip.Writer, dir, path string) error {
	name, err := filepath.Rel(dir, path)
	if err != nil {
		return fmt.Errorf("relative path of %q: %w", path, err)
	}
	if name == "." {
		// dir itself is a regular file: name the entry after the file.
		name = filepath.Base(path)
	}
	src, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("open %q: %w", path, err)
	}
	defer src.Close()
	dst, err := w.Create(filepath.ToSlash(name))
	if err != nil {
		return fmt.Errorf("create zip entry %q: %w", name, err)
	}
	if _, err := io.Copy(dst, src); err != nil {
		return fmt.Errorf("copy %q into zip: %w", path, err)
	}
	return nil
}