You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

imageinferencelogic.go 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package inference
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/go-resty/resty/v2"
  6. "github.com/zeromicro/go-zero/core/logx"
  7. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
  8. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
  9. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
  10. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
  11. "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
  12. "math/rand"
  13. "mime/multipart"
  14. "net/http"
  15. "sync"
  16. )
  17. type ImageInferenceLogic struct {
  18. logx.Logger
  19. ctx context.Context
  20. svcCtx *svc.ServiceContext
  21. }
  22. func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ImageInferenceLogic {
  23. return &ImageInferenceLogic{
  24. Logger: logx.WithContext(ctx),
  25. ctx: ctx,
  26. svcCtx: svcCtx,
  27. }
  28. }
  29. func (l *ImageInferenceLogic) ImageInference(req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
  30. return nil, nil
  31. }
  32. func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
  33. resp = &types.ImageInferenceResp{}
  34. opt := &option.InferOption{
  35. TaskName: req.TaskName,
  36. TaskDesc: req.TaskDesc,
  37. AdapterId: req.AdapterId,
  38. AiClusterIds: req.AiClusterIds,
  39. ModelName: req.ModelName,
  40. ModelType: req.ModelType,
  41. Strategy: req.Strategy,
  42. StaticWeightMap: req.StaticWeightMap,
  43. }
  44. var ts []struct {
  45. imageResult *types.ImageResult
  46. file multipart.File
  47. }
  48. uploadedFiles := r.MultipartForm.File
  49. if len(uploadedFiles) == 0 {
  50. return nil, errors.New("Images does not exist")
  51. }
  52. if len(uploadedFiles["images"]) == 0 {
  53. return nil, errors.New("Images does not exist")
  54. }
  55. for _, header := range uploadedFiles["images"] {
  56. file, err := header.Open()
  57. if err != nil {
  58. return nil, err
  59. }
  60. defer file.Close()
  61. var ir types.ImageResult
  62. ir.ImageName = header.Filename
  63. t := struct {
  64. imageResult *types.ImageResult
  65. file multipart.File
  66. }{
  67. imageResult: &ir,
  68. file: file,
  69. }
  70. ts = append(ts, t)
  71. }
  72. _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
  73. if !ok {
  74. return nil, errors.New("AdapterId does not exist")
  75. }
  76. var strat strategy.Strategy
  77. switch opt.Strategy {
  78. case strategy.STATIC_WEIGHT:
  79. //todo resources should match cluster StaticWeightMap
  80. strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts)))
  81. if err != nil {
  82. return nil, err
  83. }
  84. default:
  85. return nil, errors.New("no strategy has been chosen")
  86. }
  87. clusters, err := strat.Schedule()
  88. if err != nil {
  89. return nil, err
  90. }
  91. results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx)
  92. if err != nil {
  93. return nil, err
  94. }
  95. resp.InferResults = results
  96. return resp, nil
  97. }
  98. func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct {
  99. imageResult *types.ImageResult
  100. file multipart.File
  101. }, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) {
  102. if clusters == nil || len(clusters) == 0 {
  103. return nil, errors.New("clusters is nil")
  104. }
  105. for i := len(clusters) - 1; i >= 0; i-- {
  106. if clusters[i].Replicas == 0 {
  107. clusters = append(clusters[:i], clusters[i+1:]...)
  108. }
  109. }
  110. var wg sync.WaitGroup
  111. var cluster_ch = make(chan struct {
  112. urls []*collector.ImageInferUrl
  113. clusterName string
  114. imageNum int32
  115. }, len(clusters))
  116. var cs []struct {
  117. urls []*collector.ImageInferUrl
  118. clusterName string
  119. imageNum int32
  120. }
  121. collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
  122. for _, cluster := range clusters {
  123. wg.Add(1)
  124. c := cluster
  125. go func() {
  126. imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt)
  127. if err != nil {
  128. return
  129. }
  130. clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
  131. s := struct {
  132. urls []*collector.ImageInferUrl
  133. clusterName string
  134. imageNum int32
  135. }{
  136. urls: imageUrls,
  137. clusterName: clusterName,
  138. imageNum: c.Replicas,
  139. }
  140. cluster_ch <- s
  141. wg.Done()
  142. return
  143. }()
  144. }
  145. wg.Wait()
  146. close(cluster_ch)
  147. for s := range cluster_ch {
  148. cs = append(cs, s)
  149. }
  150. var result_ch = make(chan *types.ImageResult, len(ts))
  151. var results []*types.ImageResult
  152. wg.Add(len(ts))
  153. var imageNumIdx int32 = 0
  154. var imageNumIdxEnd int32 = 0
  155. for _, c := range cs {
  156. new_images := make([]struct {
  157. imageResult *types.ImageResult
  158. file multipart.File
  159. }, len(ts))
  160. copy(new_images, ts)
  161. imageNumIdxEnd = imageNumIdxEnd + c.imageNum
  162. new_images = new_images[imageNumIdx:imageNumIdxEnd]
  163. imageNumIdx = imageNumIdx + c.imageNum
  164. go sendInferReq(new_images, c, &wg, result_ch)
  165. }
  166. wg.Wait()
  167. close(result_ch)
  168. for s := range result_ch {
  169. results = append(results, s)
  170. }
  171. return results, nil
  172. }
  173. func sendInferReq(images []struct {
  174. imageResult *types.ImageResult
  175. file multipart.File
  176. }, cluster struct {
  177. urls []*collector.ImageInferUrl
  178. clusterName string
  179. imageNum int32
  180. }, wg *sync.WaitGroup, ch chan<- *types.ImageResult) {
  181. for _, image := range images {
  182. go func(t struct {
  183. imageResult *types.ImageResult
  184. file multipart.File
  185. }, c struct {
  186. urls []*collector.ImageInferUrl
  187. clusterName string
  188. imageNum int32
  189. }) {
  190. if len(c.urls) == 1 {
  191. r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName)
  192. if err != nil {
  193. t.imageResult.ImageResult = err.Error()
  194. ch <- t.imageResult
  195. wg.Done()
  196. return
  197. }
  198. t.imageResult.ImageResult = r
  199. t.imageResult.ClusterName = c.clusterName
  200. t.imageResult.Card = c.urls[0].Card
  201. ch <- t.imageResult
  202. wg.Done()
  203. return
  204. } else {
  205. idx := rand.Intn(len(c.urls))
  206. r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName)
  207. if err != nil {
  208. t.imageResult.ImageResult = err.Error()
  209. ch <- t.imageResult
  210. wg.Done()
  211. return
  212. }
  213. t.imageResult.ImageResult = r
  214. t.imageResult.ClusterName = c.clusterName
  215. t.imageResult.Card = c.urls[idx].Card
  216. ch <- t.imageResult
  217. wg.Done()
  218. return
  219. }
  220. }(image, cluster)
  221. }
  222. }
  223. func getInferResult(url string, file multipart.File, fileName string) (string, error) {
  224. var res Res
  225. req := GetACHttpRequest()
  226. _, err := req.
  227. SetFileReader("file", fileName, file).
  228. SetResult(&res).
  229. Post(url)
  230. if err != nil {
  231. return "", err
  232. }
  233. return res.Result, nil
  234. }
  235. func GetACHttpRequest() *resty.Request {
  236. client := resty.New()
  237. request := client.R()
  238. return request
  239. }
  240. type Res struct {
  241. Result string `json:"result"`
  242. }

PCM is positioned as a software stack over the cloud, aiming to build the standards and ecosystem of heterogeneous cloud collaboration for JCC in a non-intrusive, autonomous, peer-to-peer manner.