From 450993bd46257dc2855c4be37bf97668515051b8 Mon Sep 17 00:00:00 2001 From: zouap Date: Thu, 4 Nov 2021 16:05:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BB=A3=E7=A0=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zouap --- routers/private/tool.go | 12 +++++++ routers/repo/ai_model_manage.go | 64 +++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/routers/private/tool.go b/routers/private/tool.go index b93f17090..68de67aa7 100755 --- a/routers/private/tool.go +++ b/routers/private/tool.go @@ -44,3 +44,15 @@ func RepoStatisticManually(ctx *macaron.Context) { repo.SummaryStatisticDaily(date) repo.TimingCountDataByDate(date) } + +func CreateModel(ctx *macaron.Context) { + trainTaskId := ctx.QueryInt64("TrainTask") + name := ctx.Query("Name") + version := ctx.Query("Version") + label := ctx.Query("Label") + description := ctx.Query("Description") + userId := ctx.QueryInt64("userId") + + repo.SaveModelByParameters(trainTaskId, name, version, label, description, userId) + +} diff --git a/routers/repo/ai_model_manage.go b/routers/repo/ai_model_manage.go index 1dd1e82b2..ec99a8048 100644 --- a/routers/repo/ai_model_manage.go +++ b/routers/repo/ai_model_manage.go @@ -17,6 +17,70 @@ import ( uuid "github.com/satori/go.uuid" ) +func SaveModelByParameters(trainTaskId int64, name string, version string, label string, description string, userId int64) { + aiTasks, _, err := models.Cloudbrains(&models.CloudbrainsOptions{ + JobID: trainTaskId, + }) + if err != nil { + log.Info("query task error." + err.Error()) + //ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err)) + return + } + uuid := uuid.NewV4() + id := uuid.String() + modelPath := id + parent := id + var modelSize int64 + cloudType := models.TypeCloudBrainTwo + + if len(aiTasks) != 1 { + log.Info("query task error. len=" + fmt.Sprint(len(aiTasks))) + //ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err)) + return + } + aiTask := aiTasks[0] + log.Info("find task name:" + aiTask.JobName) + aimodels := models.QueryModelByName(name, userId) + if len(aimodels) > 0 { + for _, model := range aimodels { + if model.ID == model.Parent { + parent = model.ID + } + } + } + cloudType = aiTask.Cloudbrain.Type + //download model zip //train type + if cloudType == models.TypeCloudBrainTrainJob { + modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "") + if err == nil { + + } else { + log.Info("download model from CloudBrainTwo faild." + err.Error()) + //ctx.Error(500, fmt.Sprintf("%v", err)) + return + } + } + + model := &models.AiModelManage{ + ID: id, + Version: version, + Label: label, + Name: name, + Description: description, + Parent: parent, + Type: cloudType, + Path: modelPath, + Size: modelSize, + AttachmentId: aiTask.Uuid, + RepoId: aiTask.RepoID, + UserId: userId, + } + + models.SaveModelToDb(model) + + log.Info("save model end.") +} + func SaveModel(ctx *context.Context) { log.Info("save model start.") trainTaskId := ctx.QueryInt64("TrainTask")