You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imageinferencelogic.go 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. package inference
  2. import "C"
  3. import (
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "fmt"
  9. "github.com/JCCE-nudt/apigw-go-sdk/core"
  10. "github.com/go-resty/resty/v2"
  11. "github.com/zeromicro/go-zero/core/logx"
  12. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
  13. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
  14. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
  15. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
  16. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
  17. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
  18. "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
  19. "io"
  20. "k8s.io/apimachinery/pkg/util/json"
  21. "log"
  22. "math/rand"
  23. "mime/multipart"
  24. "net/http"
  25. "sort"
  26. "strconv"
  27. "sync"
  28. "time"
  29. )
  30. type ImageInferenceLogic struct {
  31. logx.Logger
  32. ctx context.Context
  33. svcCtx *svc.ServiceContext
  34. }
  // NewImageInferenceLogic builds the logic object for a single request,
// binding a logger derived from ctx together with ctx and the service context.
func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ImageInferenceLogic {
	logic := new(ImageInferenceLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
  42. func (l *ImageInferenceLogic) ImageInference(req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
  43. return nil, nil
  44. }
  // ImageInfer runs an image-inference task for the files uploaded in the
// "images" multipart field of r.
//
// The caller must already have parsed the multipart form into r.MultipartForm.
// If the form is missing or holds no images, the request is rejected.
// Only the static-weight strategy is supported; it decides how many images
// each cluster of req.AdapterId receives. The per-image results are returned
// in resp.InferResults.
//
// Errors:
//   - no uploaded images;
//   - an unknown adapter;
//   - an unsupported strategy;
//   - a failure to open an upload;
//   - any failure reported by scheduling or by infer.
func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
	opt := &option.InferOption{
		TaskName:        req.TaskName,
		TaskDesc:        req.TaskDesc,
		AdapterId:       req.AdapterId,
		AiClusterIds:    req.AiClusterIds,
		ModelName:       req.ModelName,
		ModelType:       req.ModelType,
		Strategy:        req.Strategy,
		StaticWeightMap: req.StaticWeightMap,
	}

	// r.MultipartForm is nil when the form was never parsed. Report that like an
	// upload without images instead of dereferencing a nil pointer.
	if r.MultipartForm == nil || len(r.MultipartForm.File["images"]) == 0 {
		return nil, errors.New("Images does not exist")
	}
	headers := r.MultipartForm.File["images"]

	ts := make([]struct {
		imageResult *types.ImageResult
		file        multipart.File
	}, 0, len(headers))
	for _, header := range headers {
		file, err := header.Open()
		if err != nil {
			return nil, err
		}
		// Deferred to ImageInfer's return on purpose: infer streams these files
		// to the inference endpoints, so they must stay open until it finishes.
		defer file.Close()
		ts = append(ts, struct {
			imageResult *types.ImageResult
			file        multipart.File
		}{
			imageResult: &types.ImageResult{ImageName: header.Filename},
			file:        file,
		})
	}

	if _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]; !ok {
		return nil, errors.New("AdapterId does not exist")
	}

	var strat strategy.Strategy
	switch opt.Strategy {
	case strategy.STATIC_WEIGHT:
		strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts)))
	default:
		return nil, errors.New("no strategy has been chosen")
	}

	clusters, err := strat.Schedule()
	if err != nil {
		return nil, err
	}
	results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx)
	if err != nil {
		return nil, err
	}
	return &types.ImageInferenceResp{InferResults: results}, nil
}
  // infer runs an image-inference task end to end and returns one result per
// dispatched image, sorted by cluster name.
//
// Steps:
//  1. Drop clusters that were assigned zero images.
//  2. Save the task row and one task_ai row per remaining cluster.
//  3. Concurrently look up each cluster's inference endpoints. Clusters whose
//     lookup fails, or yields no endpoint, are marked Failed.
//  4. Give each reachable cluster its contiguous share of ts (imageNum images,
//     in the order the clusters became reachable) and send the images
//     concurrently via sendInferReq.
//  5. Mark the reachable clusters Completed, post a notice, and save one
//     task_ai_sub row per result.
//
// An error is returned in three cases:
//   - clusters is empty;
//   - looking up or saving the task records fails;
//   - no cluster is reachable. Every task_ai row is then marked Failed and a
//     failure notice is posted.
func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct {
	imageResult *types.ImageResult
	file        multipart.File
}, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) {
	// inferCluster is an alias (identical type) for the anonymous struct that
	// sendInferReq and contains accept.
	type inferCluster = struct {
		urls        []*collector.ImageInferUrl
		clusterId   string
		clusterName string
		imageNum    int32
	}

	if len(clusters) == 0 {
		return nil, errors.New("clusters is nil")
	}
	// Filter into a new slice so the caller's backing array is left untouched.
	assigned := make([]*strategy.AssignedCluster, 0, len(clusters))
	for _, c := range clusters {
		if c.Replicas != 0 {
			assigned = append(assigned, c)
		}
	}
	clusters = assigned

	collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]

	// save task
	var synergyStatus int64
	if len(clusters) > 1 {
		synergyStatus = 1
	}
	strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy)
	if err != nil {
		return nil, err
	}
	adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
	if err != nil {
		return nil, err
	}
	id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergyStatus, "11")
	if err != nil {
		return nil, err
	}
	svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中")

	// save one task_ai row per cluster
	for _, c := range clusters {
		clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
		opt.Replica = c.Replicas
		if err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, ""); err != nil {
			return nil, err
		}
	}

	// Resolve inference endpoints for every cluster concurrently. A cluster
	// that cannot serve the request is logged and left out of cs.
	var wg sync.WaitGroup
	clusterCh := make(chan inferCluster, len(clusters))
	for _, cluster := range clusters {
		c := cluster
		wg.Add(1)
		go func() {
			defer wg.Done()
			coll, ok := collectorMap[c.ClusterId]
			if !ok {
				logx.Errorf("image infer: no collector for cluster %s", c.ClusterId)
				return
			}
			urls, err := coll.GetImageInferUrl(ctx, opt)
			if err != nil {
				logx.Errorf("image infer: get infer urls of cluster %s: %v", c.ClusterId, err)
				return
			}
			if len(urls) == 0 {
				// sendInferReq needs at least one endpoint to pick from.
				logx.Errorf("image infer: cluster %s has no infer url", c.ClusterId)
				return
			}
			name, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
			clusterCh <- inferCluster{
				urls:        urls,
				clusterId:   c.ClusterId,
				clusterName: name,
				imageNum:    c.Replicas,
			}
		}()
	}
	wg.Wait()
	close(clusterCh)
	var cs []inferCluster
	for s := range clusterCh {
		cs = append(cs, s)
	}

	var aiTaskList []*models.TaskAi
	tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
	if tx.Error != nil {
		return nil, tx.Error
	}

	// persistStatus writes t (whose Status was just changed) back to storage.
	// A failure is logged so the remaining status updates still run.
	persistStatus := func(t *models.TaskAi) {
		if err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t); err != nil {
			logx.Errorf("image infer: update task_ai %v of task %v: %v", t.Id, id, err)
		}
	}

	// no cluster available
	if len(cs) == 0 {
		for _, t := range aiTaskList {
			t.Status = constants.Failed
			persistStatus(t)
		}
		svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
		return nil, errors.New("image infer task failed")
	}

	// mark clusters that could not serve the request as failed
	if len(clusters) != len(cs) {
		for _, cluster := range clusters {
			if contains(cs, cluster.ClusterId) {
				continue
			}
			for _, t := range aiTaskList {
				if cluster.ClusterId == strconv.Itoa(int(t.ClusterId)) {
					t.Status = constants.Failed
					persistStatus(t)
				}
			}
		}
	}

	// Hand each reachable cluster its contiguous share of the images. The
	// upper bound is clamped so an over-assignment cannot slice past ts.
	// NOTE(review): images assigned to clusters that failed above are never
	// sent; redistribute them if partial failure should still cover all images.
	resultCh := make(chan *types.ImageResult, len(ts))
	total := int32(len(ts))
	var next int32
	for _, c := range cs {
		end := next + c.imageNum
		if end > total {
			end = total
		}
		// sendInferReq only reads the batch, so a sub-slice of ts is enough.
		batch := ts[next:end]
		next = end
		wg.Add(len(batch))
		go sendInferReq(batch, c, &wg, resultCh)
	}
	wg.Wait()
	close(resultCh)

	var results []*types.ImageResult
	for r := range resultCh {
		results = append(results, r)
	}
	sort.Slice(results, func(p, q int) bool {
		return results[p].ClusterName < results[q].ClusterName
	})

	// update succeeded cluster status
	var successStatusCount int
	for _, c := range cs {
		for _, t := range aiTaskList {
			if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
				t.Status = constants.Completed
				persistStatus(t)
				successStatusCount++
			}
		}
	}
	// Every reachable cluster was found and marked Completed: report success.
	if successStatusCount == len(cs) {
		svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
	} else {
		svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
	}

	// save one task_ai_sub row per image result
	for _, r := range results {
		for _, task := range aiTaskList {
			if r.ClusterId != strconv.Itoa(int(task.ClusterId)) {
				continue
			}
			taskAiSub := models.TaskAiSub{
				TaskId:   id,
				TaskName: task.Name,
				// NOTE(review): filled from task.TaskId (the parent task id),
				// not task.Id; confirm which id this column should hold.
				TaskAiId:    task.TaskId,
				TaskAiName:  task.Name,
				ImageName:   r.ImageName,
				Result:      r.ImageResult,
				Card:        r.Card,
				ClusterId:   task.ClusterId,
				ClusterName: r.ClusterName,
			}
			if tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub); tx.Error != nil {
				logx.Errorf("image infer: save task_ai_sub for image %s: %v", r.ImageName, tx.Error)
			}
		}
	}
	return results, nil
}
  // sendInferReq sends every image in images to one of cluster's inference
// endpoints, each on its own goroutine.
//
// Endpoint choice: with exactly one endpoint it is always used; with several,
// one is picked uniformly at random per image.
//
// Each image's result goes into its *types.ImageResult, together with the
// cluster id, cluster name and the chosen endpoint's card. The result pointer
// is then sent on ch, and only after that is wg.Done called. A request error
// is stored as the result text. A cluster without endpoints yields an error
// result instead of panicking.
//
// The caller must wg.Add(len(images)) beforehand, and ch must have room for
// len(images) results or be drained concurrently.
func sendInferReq(images []struct {
	imageResult *types.ImageResult
	file        multipart.File
}, cluster struct {
	urls        []*collector.ImageInferUrl
	clusterId   string
	clusterName string
	imageNum    int32
}, wg *sync.WaitGroup, ch chan<- *types.ImageResult) {
	for _, image := range images {
		go func(res *types.ImageResult, file multipart.File) {
			// Deferred so Done happens only after the result was sent on ch.
			defer wg.Done()

			res.ClusterId = cluster.clusterId
			res.ClusterName = cluster.clusterName
			if len(cluster.urls) == 0 {
				// rand.Intn(0) would panic; report the problem per image instead.
				res.ImageResult = "no inference url available"
				ch <- res
				return
			}

			idx := 0
			if len(cluster.urls) > 1 {
				idx = rand.Intn(len(cluster.urls))
			}
			endpoint := cluster.urls[idx]
			res.Card = endpoint.Card

			out, err := getInferResult(endpoint.Url, file, res.ImageName, cluster.clusterName)
			if err != nil {
				out = err.Error()
			}
			res.ImageResult = out
			ch <- res
		}(image.imageResult, image.file)
	}
}
  // getInferResult sends one image to an inference endpoint and returns the
// "result" field of the endpoint's JSON reply.
//
// The cluster named "鹏city云脑II-modelarts" is special: it goes through
// getInferResultModelarts, which signs the request for that gateway.
// Every other cluster gets a plain multipart upload (form field "file") via
// resty with a 20-second timeout. An HTTP error status (>= 400) is returned as
// an error rather than being mistaken for an empty result.
func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) {
	if clusterName == "鹏城云脑II-modelarts" {
		return getInferResultModelarts(url, file, fileName)
	}

	var res Res
	resp, err := GetRestyRequest(20).
		SetFileReader("file", fileName, file).
		SetResult(&res).
		Post(url)
	if err != nil {
		return "", err
	}
	// resty fills the result struct only for 2xx replies; an error status would
	// otherwise come back as "" with a nil error.
	if resp.IsError() {
		return "", fmt.Errorf("infer request to %s failed: %s", url, resp.Status())
	}
	return res.Result, nil
}
  // getInferResultModelarts sends one image to a ModelArts inference endpoint
// via the signed SendRequest path and returns the "result" field of the JSON
// reply. A transport failure or a reply that is not valid JSON is returned as
// an error.
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
	body, err := SendRequest("POST", url, file, fileName)
	if err != nil {
		return "", err
	}
	var res Res
	if err := json.Unmarshal([]byte(body), &res); err != nil {
		// Never exit the process here: a malformed reply from one endpoint must
		// only fail this image, not the whole service.
		log.Printf("Error parsing JSON: %s", err)
		return "", fmt.Errorf("parsing modelarts response: %w", err)
	}
	return res.Result, nil
}
  415. // SignClient AK/SK签名认证
  416. func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) {
  417. r.Header.Add("content-type", "application/json;charset=UTF-8")
  418. r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52")
  419. r.Header.Add("x-stage", "RELEASE")
  420. r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD")
  421. r.Header.Set("Content-Type", writer.FormDataContentType())
  422. s := core.Signer{
  423. Key: "UNEHPHO4Z7YSNPKRXFE4",
  424. Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
  425. }
  426. err := s.Sign(r)
  427. if err != nil {
  428. return nil, err
  429. }
  430. //设置client信任所有证书
  431. tr := &http.Transport{
  432. TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
  433. }
  434. client := &http.Client{
  435. Transport: tr,
  436. }
  437. return client, nil
  438. }
  // SendRequest uploads file as the multipart form field "file" (filename
// fileName) to url with the given HTTP method. The request is signed through
// SignClient, and the response body is returned as a string.
//
// Every failure is returned as an error with an empty string:
//   - building the form;
//   - creating or signing the request;
//   - the round trip;
//   - reading the body;
//   - an HTTP error status (>= 400).
func SendRequest(method, url string, file multipart.File, fileName string) (string, error) {
	// Build the multipart body in memory.
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	part, err := writer.CreateFormFile("file", fileName)
	if err != nil {
		return "", fmt.Errorf("creating form file: %w", err)
	}
	if _, err := io.Copy(part, file); err != nil {
		return "", fmt.Errorf("copying file data: %w", err)
	}
	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return "", fmt.Errorf("closing multipart writer: %w", err)
	}

	request, err := http.NewRequest(method, url, &body)
	if err != nil {
		return "", fmt.Errorf("creating request: %w", err)
	}
	client, err := SignClient(request, writer)
	if err != nil {
		return "", fmt.Errorf("signing request: %w", err)
	}
	res, err := client.Do(request)
	if err != nil {
		return "", fmt.Errorf("sending request: %w", err)
	}
	defer res.Body.Close()

	resBody, err := io.ReadAll(res.Body)
	if err != nil {
		return "", fmt.Errorf("reading response body: %w", err)
	}
	if res.StatusCode >= http.StatusBadRequest {
		return "", fmt.Errorf("%s %s: unexpected status %s: %s", method, url, res.Status, resBody)
	}
	return string(resBody), nil
}
  // GetRestyRequest returns a request from a brand-new resty client whose
// timeout is timeoutSeconds seconds. Each call creates its own client.
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
	timeout := time.Second * time.Duration(timeoutSeconds)
	return resty.New().SetTimeout(timeout).R()
}
  // Res is the JSON reply decoded from an inference endpoint; only its
// "result" field is read.
type (
	Res struct {
		Result string `json:"result"`
	}
)
  491. func contains(cs []struct {
  492. urls []*collector.ImageInferUrl
  493. clusterId string
  494. clusterName string
  495. imageNum int32
  496. }, e string) bool {
  497. for _, c := range cs {
  498. if c.clusterId == e {
  499. return true
  500. }
  501. }
  502. return false
  503. }

PCM is positioned as a software stack over the cloud, aiming to build the standards and ecosystem for heterogeneous cloud collaboration in JCC in a non-intrusive, autonomous, peer-to-peer manner.