You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ai_model_manage.go 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package repo
  2. import (
  3. "archive/zip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "os"
  9. "path"
  10. "path/filepath"
  11. "strings"
  12. "code.gitea.io/gitea/models"
  13. "code.gitea.io/gitea/modules/context"
  14. "code.gitea.io/gitea/modules/log"
  15. "code.gitea.io/gitea/modules/setting"
  16. "code.gitea.io/gitea/modules/storage"
  17. uuid "github.com/satori/go.uuid"
  18. )
  19. func SaveModelByParameters(trainTaskId string, name string, version string, label string, description string, userId int64) {
  20. aiTask, err := models.GetCloudbrainByJobID(trainTaskId)
  21. if err != nil {
  22. log.Info("query task error." + err.Error())
  23. //ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  24. return
  25. }
  26. uuid := uuid.NewV4()
  27. id := uuid.String()
  28. modelPath := id
  29. parent := id
  30. var modelSize int64
  31. cloudType := models.TypeCloudBrainTwo
  32. log.Info("find task name:" + aiTask.JobName)
  33. aimodels := models.QueryModelByName(name, userId)
  34. if len(aimodels) > 0 {
  35. for _, model := range aimodels {
  36. if model.ID == model.Parent {
  37. parent = model.ID
  38. }
  39. }
  40. }
  41. cloudType = aiTask.Type
  42. //download model zip //train type
  43. if cloudType == models.TypeCloudBrainTwo {
  44. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  45. if err == nil {
  46. } else {
  47. log.Info("download model from CloudBrainTwo faild." + err.Error())
  48. //ctx.Error(500, fmt.Sprintf("%v", err))
  49. return
  50. }
  51. }
  52. model := &models.AiModelManage{
  53. ID: id,
  54. Version: version,
  55. Label: label,
  56. Name: name,
  57. Description: description,
  58. Parent: parent,
  59. Type: cloudType,
  60. Path: modelPath,
  61. Size: modelSize,
  62. AttachmentId: aiTask.Uuid,
  63. RepoId: aiTask.RepoID,
  64. UserId: userId,
  65. }
  66. models.SaveModelToDb(model)
  67. log.Info("save model end.")
  68. }
  69. func SaveModel(ctx *context.Context) {
  70. log.Info("save model start.")
  71. trainTaskId := ctx.QueryInt64("TrainTask")
  72. name := ctx.Query("Name")
  73. version := ctx.Query("Version")
  74. label := ctx.Query("Label")
  75. description := ctx.Query("Description")
  76. aiTasks, _, err := models.Cloudbrains(&models.CloudbrainsOptions{
  77. JobID: trainTaskId,
  78. })
  79. if err != nil {
  80. log.Info("query task error." + err.Error())
  81. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  82. return
  83. }
  84. uuid := uuid.NewV4()
  85. id := uuid.String()
  86. modelPath := id
  87. parent := id
  88. var modelSize int64
  89. cloudType := models.TypeCloudBrainTwo
  90. if len(aiTasks) != 1 {
  91. log.Info("query task error. len=" + fmt.Sprint(len(aiTasks)))
  92. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  93. return
  94. }
  95. aiTask := aiTasks[0]
  96. log.Info("find task name:" + aiTask.JobName)
  97. aimodels := models.QueryModelByName(name, ctx.User.ID)
  98. if len(aimodels) > 0 {
  99. for _, model := range aimodels {
  100. if model.ID == model.Parent {
  101. parent = model.ID
  102. }
  103. }
  104. }
  105. cloudType = aiTask.Cloudbrain.Type
  106. //download model zip //train type
  107. if cloudType == models.TypeCloudBrainTrainJob {
  108. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  109. if err == nil {
  110. } else {
  111. log.Info("download model from CloudBrainTwo faild." + err.Error())
  112. ctx.Error(500, fmt.Sprintf("%v", err))
  113. return
  114. }
  115. }
  116. model := &models.AiModelManage{
  117. ID: id,
  118. Version: version,
  119. Label: label,
  120. Name: name,
  121. Description: description,
  122. Parent: parent,
  123. Type: cloudType,
  124. Path: modelPath,
  125. Size: modelSize,
  126. AttachmentId: aiTask.Uuid,
  127. RepoId: aiTask.RepoID,
  128. UserId: ctx.User.ID,
  129. }
  130. models.SaveModelToDb(model)
  131. log.Info("save model end.")
  132. }
  133. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  134. dataActualPath := setting.Bucket + "/" +
  135. "aimodels/" +
  136. models.AttachmentRelativePath(modelUUID) +
  137. "/"
  138. modelDbResult, err := storage.GetObsListObject(jobName, parentDir)
  139. if err != nil {
  140. log.Info("get TrainJobListModel failed:", err)
  141. return "", 0, err
  142. }
  143. if len(modelDbResult) == 0 {
  144. return "", 0, errors.New("cannot create model, as model is empty.")
  145. }
  146. prefix := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
  147. for _, modelFile := range modelDbResult {
  148. destKeyName := "/aimodels/" + models.AttachmentRelativePath(modelUUID) + "/"
  149. log.Info("copy file, bucket=%s, src keyname=%s, dest keyname=%s,", setting.Bucket, prefix+modelFile.FileName, destKeyName)
  150. // err := storage.ObsCopyFile(setting.Bucket, modelFile.ParenDir+modelFile.FileName, setting.Bucket, dataActualPath+modelFile.FileName)
  151. // if err != nil {
  152. // log.Info("copy failed.")
  153. // }
  154. }
  155. return dataActualPath, 0, nil
  156. }
  157. func DeleteModel(ctx *context.Context) {
  158. log.Info("delete model start.")
  159. id := ctx.Query("ID")
  160. err := DeleteModelByID(id)
  161. if err != nil {
  162. ctx.JSON(500, err.Error())
  163. } else {
  164. ctx.JSON(200, map[string]string{
  165. "result_code": "0",
  166. })
  167. }
  168. }
  169. func DeleteModelByID(id string) error {
  170. log.Info("delete model start. id=" + id)
  171. return models.DeleteModelById(id)
  172. }
  173. func DownloadModel(ctx *context.Context) {
  174. log.Info("download model start.")
  175. }
  176. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  177. return models.QueryModel(&models.AiModelQueryOptions{
  178. ListOptions: models.ListOptions{
  179. Page: page,
  180. PageSize: setting.UI.IssuePagingNum,
  181. },
  182. RepoID: repoId,
  183. Type: -1,
  184. })
  185. }
  186. func ShowModelInfo(ctx *context.Context) {
  187. log.Info("ShowModelInfo start.")
  188. page := ctx.QueryInt("page")
  189. if page <= 0 {
  190. page = 1
  191. }
  192. repoId := ctx.QueryInt64("repoId")
  193. Type := -1
  194. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  195. ListOptions: models.ListOptions{
  196. Page: page,
  197. PageSize: setting.UI.IssuePagingNum,
  198. },
  199. RepoID: repoId,
  200. Type: Type,
  201. })
  202. if err != nil {
  203. ctx.ServerError("Cloudbrain", err)
  204. return
  205. }
  206. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  207. pager.SetDefaultParams(ctx)
  208. ctx.Data["Page"] = pager
  209. ctx.Data["PageIsCloudBrain"] = true
  210. ctx.Data["Tasks"] = modelResult
  211. ctx.HTML(200, "")
  212. }
  213. func downloadModelFromCloudBrainOne(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  214. modelActualPath := setting.Attachment.Minio.RealPath +
  215. setting.Attachment.Minio.Bucket + "/" +
  216. "aimodels/" +
  217. models.AttachmentRelativePath(modelUUID) +
  218. "/"
  219. os.MkdirAll(modelActualPath, 0755)
  220. zipFile := modelActualPath + "model.zip"
  221. modelDir := setting.JobPath + jobName + "/model/"
  222. dir, _ := ioutil.ReadDir(modelDir)
  223. if len(dir) == 0 {
  224. return "", 0, errors.New("cannot create model, as model is empty.")
  225. }
  226. err := zipDir(modelDir, zipFile)
  227. if err != nil {
  228. return "", 0, err
  229. }
  230. fi, err := os.Stat(zipFile)
  231. if err == nil {
  232. return modelActualPath, fi.Size(), nil
  233. } else {
  234. return "", 0, err
  235. }
  236. }
  // zipDir writes every non-directory file found under dir (recursively) into a
  // new zip archive at zipFile. Entry names are the files' paths relative to
  // dir, using forward slashes.
  //
  // dir may be given with or without a trailing path separator. An existing
  // zipFile is truncated. The returned error is non-nil when the archive cannot
  // be created, when walking dir fails (including dir not existing), when a
  // file cannot be opened or copied, or when finalizing the archive fails.
  func zipDir(dir, zipFile string) (err error) {
  	fz, err := os.Create(zipFile)
  	if err != nil {
  		return fmt.Errorf("create zip file %q: %w", zipFile, err)
  	}
  	defer func() {
  		// Report a failed close only if nothing went wrong earlier.
  		if cerr := fz.Close(); cerr != nil && err == nil {
  			err = fmt.Errorf("close zip file %q: %w", zipFile, cerr)
  		}
  	}()

  	w := zip.NewWriter(fz)

  	walkErr := filepath.Walk(dir, func(p string, info os.FileInfo, werr error) error {
  		if werr != nil {
  			// info may be nil when werr is set, so bail out before using it.
  			return fmt.Errorf("walk %q: %w", p, werr)
  		}
  		if info.IsDir() {
  			return nil
  		}

  		// Walk hands out cleaned paths, so slicing by len(dir) is wrong when
  		// dir carries a trailing separator; compute the relative name instead.
  		rel, err := filepath.Rel(dir, p)
  		if err != nil {
  			return fmt.Errorf("relative path of %q: %w", p, err)
  		}
  		if rel == "." {
  			// dir itself is a regular file.
  			rel = info.Name()
  		}

  		fDest, err := w.Create(filepath.ToSlash(rel))
  		if err != nil {
  			return fmt.Errorf("create zip entry %q: %w", rel, err)
  		}

  		fSrc, err := os.Open(p)
  		if err != nil {
  			return fmt.Errorf("open %q: %w", p, err)
  		}
  		// Runs when this callback returns, i.e. once per file.
  		defer fSrc.Close()

  		if _, err = io.Copy(fDest, fSrc); err != nil {
  			return fmt.Errorf("copy %q into zip: %w", p, err)
  		}
  		return nil
  	})

  	// Close writes the central directory; without it the archive is invalid,
  	// so its error must not be dropped.
  	if cerr := w.Close(); cerr != nil && walkErr == nil {
  		walkErr = fmt.Errorf("finalize zip file %q: %w", zipFile, cerr)
  	}
  	return walkErr
  }