Former-commit-id: 86418e14ba
pull/295/head
| @@ -20,7 +20,7 @@ type ( | |||
| TaskDesc string `form:"taskDesc"` | |||
| ModelName string `form:"modelName"` | |||
| ModelType string `form:"modelType"` | |||
| AdapterId string `form:"adapterId"` | |||
| AdapterIds []string `form:"adapterIds"` | |||
| AiClusterIds []string `form:"aiClusterIds,optional"` | |||
| ResourceType string `form:"resourceType,optional"` | |||
| ComputeCard string `form:"card,optional"` | |||
| @@ -76,6 +76,18 @@ type ( | |||
| } | |||
| /******************TextToImage inference*************************/ | |||
| TextToImageInferenceReq{ | |||
| TaskName string `form:"taskName"` | |||
| TaskDesc string `form:"taskDesc"` | |||
| ModelName string `form:"modelName"` | |||
| ModelType string `form:"modelType"` | |||
| AiClusterIds []string `form:"aiClusterIds"` | |||
| } | |||
| TextToImageInferenceResp{ | |||
| Result []byte | |||
| } | |||
| /******************Deploy instance*************************/ | |||
| DeployInstanceListReq{ | |||
| PageInfo | |||
| @@ -146,6 +158,7 @@ type ( | |||
| } | |||
| GetRunningInstanceReq { | |||
| AdapterIds []string `form:"adapterIds"` | |||
| ModelType string `path:"modelType"` | |||
| ModelName string `path:"modelName"` | |||
| } | |||
| @@ -1,28 +1,25 @@ | |||
| package inference | |||
| import ( | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" | |||
| "net/http" | |||
| "github.com/zeromicro/go-zero/rest/httpx" | |||
| "gitlink.org.cn/tzwang/pcm-coordinator/internal/logic/inference" | |||
| "gitlink.org.cn/tzwang/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/tzwang/pcm-coordinator/internal/types" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/logic/inference" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| ) | |||
| func GetRunningInstanceByModelHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { | |||
| return func(w http.ResponseWriter, r *http.Request) { | |||
| var req types.GetRunningInstanceReq | |||
| if err := httpx.Parse(r, &req); err != nil { | |||
| httpx.ErrorCtx(r.Context(), w, err) | |||
| result.ParamErrorResult(r, w, err) | |||
| return | |||
| } | |||
| l := inference.NewGetRunningInstanceByModelLogic(r.Context(), svcCtx) | |||
| resp, err := l.GetRunningInstanceByModel(&req) | |||
| if err != nil { | |||
| httpx.ErrorCtx(r.Context(), w, err) | |||
| } else { | |||
| httpx.OkJsonCtx(r.Context(), w, resp) | |||
| } | |||
| result.HttpResult(r, w, resp, err) | |||
| } | |||
| } | |||
| @@ -4,6 +4,7 @@ import ( | |||
| "context" | |||
| "errors" | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/common" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/updater" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| @@ -30,14 +31,34 @@ func (l *DeployInstanceListLogic) DeployInstanceList(req *types.DeployInstanceLi | |||
| offset := req.PageSize * (req.PageNum - 1) | |||
| resp = &types.DeployInstanceListResp{} | |||
| var list []*models.AiInferDeployInstance | |||
| tx := l.svcCtx.DbEngin.Raw("select * from ai_infer_deploy_instance").Scan(&list) | |||
| var tasklist []*models.AiDeployInstanceTask | |||
| tx := l.svcCtx.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&tasklist) | |||
| if tx.Error != nil { | |||
| logx.Errorf(tx.Error.Error()) | |||
| return nil, tx.Error | |||
| } | |||
| //count total | |||
| var total int64 | |||
| err = tx.Count(&total).Error | |||
| tx.Limit(limit).Offset(offset) | |||
| if err != nil { | |||
| return resp, err | |||
| } | |||
| err = tx.Order("create_time desc").Find(&tasklist).Error | |||
| if err != nil { | |||
| return nil, errors.New(err.Error()) | |||
| } | |||
| deployTasks := l.GenerateDeployTasks(tasklist) | |||
| slices := make([][]*models.AiInferDeployInstance, len(deployTasks)) | |||
| for i := 0; i < len(deployTasks); i++ { | |||
| slices[i] = deployTasks[i].Instances | |||
| } | |||
| list := common.ConcatMultipleSlices(slices) | |||
| if len(list) == 0 { | |||
| return | |||
| } | |||
| @@ -55,23 +76,35 @@ func (l *DeployInstanceListLogic) DeployInstanceList(req *types.DeployInstanceLi | |||
| go updater.UpdateDeployInstanceStatus(l.svcCtx, ins, true) | |||
| go updater.UpdateDeployTaskStatus(l.svcCtx) | |||
| //count total | |||
| var total int64 | |||
| err = tx.Count(&total).Error | |||
| tx.Limit(limit).Offset(offset) | |||
| if err != nil { | |||
| return resp, err | |||
| } | |||
| err = tx.Order("create_time desc").Find(&list).Error | |||
| if err != nil { | |||
| return nil, errors.New(err.Error()) | |||
| } | |||
| resp.List = &list | |||
| resp.List = &deployTasks | |||
| resp.PageSize = req.PageSize | |||
| resp.PageNum = req.PageNum | |||
| resp.Total = total | |||
| return | |||
| } | |||
| func (l *DeployInstanceListLogic) GenerateDeployTasks(tasklist []*models.AiDeployInstanceTask) []*DeployTask { | |||
| var tasks []*DeployTask | |||
| for _, t := range tasklist { | |||
| list, err := l.svcCtx.Scheduler.AiStorages.GetInstanceListByDeployTaskId(t.Id) | |||
| if err != nil { | |||
| logx.Errorf("db GetInstanceListByDeployTaskId error") | |||
| continue | |||
| } | |||
| deployTask := &DeployTask{ | |||
| Id: t.Id, | |||
| Name: t.Name, | |||
| Instances: list, | |||
| } | |||
| tasks = append(tasks, deployTask) | |||
| } | |||
| return tasks | |||
| } | |||
| type DeployTask struct { | |||
| Id int64 `json:"id,string"` | |||
| Name string `json:"name,string"` | |||
| Instances []*models.AiInferDeployInstance `json:"instances,string"` | |||
| } | |||
| @@ -3,8 +3,8 @@ package inference | |||
| import ( | |||
| "context" | |||
| "gitlink.org.cn/tzwang/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/tzwang/pcm-coordinator/internal/types" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" | |||
| "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" | |||
| "github.com/zeromicro/go-zero/core/logx" | |||
| ) | |||
| @@ -24,7 +24,7 @@ func NewGetRunningInstanceByModelLogic(ctx context.Context, svcCtx *svc.ServiceC | |||
| } | |||
| func (l *GetRunningInstanceByModelLogic) GetRunningInstanceByModel(req *types.GetRunningInstanceReq) (resp *types.GetRunningInstanceResp, err error) { | |||
| // todo: add your logic here and delete this line | |||
| resp = &types.GetRunningInstanceResp{} | |||
| return | |||
| } | |||
| @@ -97,3 +97,21 @@ func Contains(s []string, e string) bool { | |||
| } | |||
| return false | |||
| } | |||
// ConcatMultipleSlices flattens slices into a single newly allocated slice,
// keeping the elements in their original order.
//
// The result never aliases any of the inputs and is always non-nil, even when
// slices is nil or contains only empty slices.
func ConcatMultipleSlices[T any](slices [][]T) []T {
	// First pass: size the destination so it is allocated exactly once.
	size := 0
	for idx := range slices {
		size += len(slices[idx])
	}

	// Second pass: append each part in order.
	merged := make([]T, 0, size)
	for idx := range slices {
		merged = append(merged, slices[idx]...)
	}
	return merged
}
| @@ -485,6 +485,16 @@ func (s *AiStorage) GetInferDeployInstanceList() ([]*models.AiInferDeployInstanc | |||
| return list, nil | |||
| } | |||
| func (s *AiStorage) GetDeployTaskList() ([]*models.AiDeployInstanceTask, error) { | |||
| var list []*models.AiDeployInstanceTask | |||
| tx := s.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&list) | |||
| if tx.Error != nil { | |||
| logx.Errorf(tx.Error.Error()) | |||
| return nil, tx.Error | |||
| } | |||
| return list, nil | |||
| } | |||
| func (s *AiStorage) GetInferDeployInstanceTotalNum() (int32, error) { | |||
| var total int32 | |||
| tx := s.DbEngin.Raw("select count(*) from ai_infer_deploy_instance").Scan(&total) | |||
| @@ -563,3 +573,13 @@ func (s *AiStorage) SaveInferDeployTask(taskName string, modelName string, model | |||
| } | |||
| return taskModel.Id, nil | |||
| } | |||
| func (s *AiStorage) GetRunningDeployInstanceByModelNameAndAdapterId(modelType string, modelName string, adapterId string) ([]*models.AiInferDeployInstance, error) { | |||
| var list []*models.AiInferDeployInstance | |||
| tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `model_type` = ? and `model_name` = ? and `adapter_id` = ? and `status` = 'Running'", modelType, modelName, adapterId).Scan(&list) | |||
| if tx.Error != nil { | |||
| logx.Errorf(tx.Error.Error()) | |||
| return nil, tx.Error | |||
| } | |||
| return list, nil | |||
| } | |||
| @@ -82,6 +82,7 @@ var ( | |||
| "image_classification": {"imagenet_resnet50"}, | |||
| "text_to_text": {"chatGLM_6B"}, | |||
| "image_to_text": {"blip-image-captioning-base"}, | |||
| "text_to_image": {"stable-diffusion-xl-base-1.0"}, | |||
| } | |||
| AITYPE = map[string]string{ | |||
| "1": OCTOPUS, | |||