|
|
|
@@ -2,11 +2,18 @@ package inference |
|
|
|
|
|
|
|
import ( |
|
|
|
"context" |
|
|
|
|
|
|
|
"errors" |
|
|
|
"github.com/go-resty/resty/v2" |
|
|
|
"github.com/zeromicro/go-zero/core/logx" |
|
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" |
|
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" |
|
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" |
|
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" |
|
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" |
|
|
|
|
|
|
|
"github.com/zeromicro/go-zero/core/logx" |
|
|
|
"math/rand" |
|
|
|
"mime/multipart" |
|
|
|
"net/http" |
|
|
|
"sync" |
|
|
|
) |
|
|
|
|
|
|
|
type ImageInferenceLogic struct { |
|
|
|
@@ -24,7 +31,250 @@ func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Im |
|
|
|
} |
|
|
|
|
|
|
|
func (l *ImageInferenceLogic) ImageInference(req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { |
|
|
|
// todo: add your logic here and delete this line |
|
|
|
return nil, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { |
|
|
|
resp = &types.ImageInferenceResp{} |
|
|
|
opt := &option.InferOption{ |
|
|
|
TaskName: req.TaskName, |
|
|
|
TaskDesc: req.TaskDesc, |
|
|
|
AdapterId: req.AdapterId, |
|
|
|
AiClusterIds: req.AiClusterIds, |
|
|
|
ModelName: req.ModelName, |
|
|
|
ModelType: req.ModelType, |
|
|
|
Strategy: req.Strategy, |
|
|
|
StaticWeightMap: req.StaticWeightMap, |
|
|
|
} |
|
|
|
|
|
|
|
var ts []struct { |
|
|
|
imageResult *types.ImageResult |
|
|
|
file multipart.File |
|
|
|
} |
|
|
|
|
|
|
|
uploadedFiles := r.MultipartForm.File |
|
|
|
|
|
|
|
if len(uploadedFiles) == 0 { |
|
|
|
return nil, errors.New("Images does not exist") |
|
|
|
} |
|
|
|
|
|
|
|
if len(uploadedFiles["images"]) == 0 { |
|
|
|
return nil, errors.New("Images does not exist") |
|
|
|
} |
|
|
|
|
|
|
|
for _, header := range uploadedFiles["images"] { |
|
|
|
file, err := header.Open() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
defer file.Close() |
|
|
|
var ir types.ImageResult |
|
|
|
ir.ImageName = header.Filename |
|
|
|
t := struct { |
|
|
|
imageResult *types.ImageResult |
|
|
|
file multipart.File |
|
|
|
}{ |
|
|
|
imageResult: &ir, |
|
|
|
file: file, |
|
|
|
} |
|
|
|
ts = append(ts, t) |
|
|
|
} |
|
|
|
|
|
|
|
_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] |
|
|
|
if !ok { |
|
|
|
return nil, errors.New("AdapterId does not exist") |
|
|
|
} |
|
|
|
|
|
|
|
var strat strategy.Strategy |
|
|
|
switch opt.Strategy { |
|
|
|
case strategy.STATIC_WEIGHT: |
|
|
|
//todo resources should match cluster StaticWeightMap |
|
|
|
strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
default: |
|
|
|
return nil, errors.New("no strategy has been chosen") |
|
|
|
} |
|
|
|
clusters, err := strat.Schedule() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
resp.InferResults = results |
|
|
|
|
|
|
|
return resp, nil |
|
|
|
} |
|
|
|
|
|
|
|
func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct { |
|
|
|
imageResult *types.ImageResult |
|
|
|
file multipart.File |
|
|
|
}, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) { |
|
|
|
|
|
|
|
if clusters == nil || len(clusters) == 0 { |
|
|
|
return nil, errors.New("clusters is nil") |
|
|
|
} |
|
|
|
|
|
|
|
for i := len(clusters) - 1; i >= 0; i-- { |
|
|
|
if clusters[i].Replicas == 0 { |
|
|
|
clusters = append(clusters[:i], clusters[i+1:]...) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
var wg sync.WaitGroup |
|
|
|
var cluster_ch = make(chan struct { |
|
|
|
urls []*collector.ImageInferUrl |
|
|
|
clusterName string |
|
|
|
imageNum int32 |
|
|
|
}, len(clusters)) |
|
|
|
|
|
|
|
var cs []struct { |
|
|
|
urls []*collector.ImageInferUrl |
|
|
|
clusterName string |
|
|
|
imageNum int32 |
|
|
|
} |
|
|
|
collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] |
|
|
|
|
|
|
|
for _, cluster := range clusters { |
|
|
|
wg.Add(1) |
|
|
|
c := cluster |
|
|
|
go func() { |
|
|
|
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) |
|
|
|
if err != nil { |
|
|
|
return |
|
|
|
} |
|
|
|
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) |
|
|
|
|
|
|
|
s := struct { |
|
|
|
urls []*collector.ImageInferUrl |
|
|
|
clusterName string |
|
|
|
imageNum int32 |
|
|
|
}{ |
|
|
|
urls: imageUrls, |
|
|
|
clusterName: clusterName, |
|
|
|
imageNum: c.Replicas, |
|
|
|
} |
|
|
|
|
|
|
|
cluster_ch <- s |
|
|
|
wg.Done() |
|
|
|
return |
|
|
|
}() |
|
|
|
} |
|
|
|
wg.Wait() |
|
|
|
close(cluster_ch) |
|
|
|
|
|
|
|
for s := range cluster_ch { |
|
|
|
cs = append(cs, s) |
|
|
|
} |
|
|
|
|
|
|
|
var result_ch = make(chan *types.ImageResult, len(ts)) |
|
|
|
var results []*types.ImageResult |
|
|
|
|
|
|
|
wg.Add(len(ts)) |
|
|
|
|
|
|
|
var imageNumIdx int32 = 0 |
|
|
|
var imageNumIdxEnd int32 = 0 |
|
|
|
for _, c := range cs { |
|
|
|
new_images := make([]struct { |
|
|
|
imageResult *types.ImageResult |
|
|
|
file multipart.File |
|
|
|
}, len(ts)) |
|
|
|
copy(new_images, ts) |
|
|
|
|
|
|
|
imageNumIdxEnd = imageNumIdxEnd + c.imageNum |
|
|
|
new_images = new_images[imageNumIdx:imageNumIdxEnd] |
|
|
|
imageNumIdx = imageNumIdx + c.imageNum |
|
|
|
|
|
|
|
go sendInferReq(new_images, c, &wg, result_ch) |
|
|
|
} |
|
|
|
wg.Wait() |
|
|
|
|
|
|
|
close(result_ch) |
|
|
|
|
|
|
|
for s := range result_ch { |
|
|
|
results = append(results, s) |
|
|
|
} |
|
|
|
|
|
|
|
return results, nil |
|
|
|
} |
|
|
|
|
|
|
|
func sendInferReq(images []struct { |
|
|
|
imageResult *types.ImageResult |
|
|
|
file multipart.File |
|
|
|
}, cluster struct { |
|
|
|
urls []*collector.ImageInferUrl |
|
|
|
clusterName string |
|
|
|
imageNum int32 |
|
|
|
}, wg *sync.WaitGroup, ch chan<- *types.ImageResult) { |
|
|
|
for _, image := range images { |
|
|
|
go func(t struct { |
|
|
|
imageResult *types.ImageResult |
|
|
|
file multipart.File |
|
|
|
}, c struct { |
|
|
|
urls []*collector.ImageInferUrl |
|
|
|
clusterName string |
|
|
|
imageNum int32 |
|
|
|
}) { |
|
|
|
if len(c.urls) == 1 { |
|
|
|
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) |
|
|
|
if err != nil { |
|
|
|
t.imageResult.ImageResult = err.Error() |
|
|
|
ch <- t.imageResult |
|
|
|
wg.Done() |
|
|
|
return |
|
|
|
} |
|
|
|
t.imageResult.ImageResult = r |
|
|
|
t.imageResult.ClusterName = c.clusterName |
|
|
|
t.imageResult.Card = c.urls[0].Card |
|
|
|
|
|
|
|
ch <- t.imageResult |
|
|
|
wg.Done() |
|
|
|
return |
|
|
|
} else { |
|
|
|
idx := rand.Intn(len(c.urls)) |
|
|
|
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) |
|
|
|
if err != nil { |
|
|
|
t.imageResult.ImageResult = err.Error() |
|
|
|
ch <- t.imageResult |
|
|
|
wg.Done() |
|
|
|
return |
|
|
|
} |
|
|
|
t.imageResult.ImageResult = r |
|
|
|
t.imageResult.ClusterName = c.clusterName |
|
|
|
t.imageResult.Card = c.urls[idx].Card |
|
|
|
|
|
|
|
ch <- t.imageResult |
|
|
|
wg.Done() |
|
|
|
return |
|
|
|
} |
|
|
|
}(image, cluster) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func getInferResult(url string, file multipart.File, fileName string) (string, error) { |
|
|
|
var res Res |
|
|
|
req := GetACHttpRequest() |
|
|
|
_, err := req. |
|
|
|
SetFileReader("file", fileName, file). |
|
|
|
SetResult(&res). |
|
|
|
Post(url) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
return "", err |
|
|
|
} |
|
|
|
return res.Result, nil |
|
|
|
} |
|
|
|
|
|
|
|
func GetACHttpRequest() *resty.Request { |
|
|
|
client := resty.New() |
|
|
|
request := client.R() |
|
|
|
return request |
|
|
|
} |
|
|
|
|
|
|
|
return |
|
|
|
type Res struct { |
|
|
|
Result string `json:"result"` |
|
|
|
} |