diff --git a/internal/logic/inference/createinferencetasklogic.go b/internal/logic/inference/createinferencetasklogic.go index 12ea55fd..2f9475ae 100644 --- a/internal/logic/inference/createinferencetasklogic.go +++ b/internal/logic/inference/createinferencetasklogic.go @@ -57,7 +57,7 @@ func (l *CreateInferenceTaskLogic) CreateInferenceTask(req *types.CreateInferenc return nil, err } - assignedClusters := task.CopyParams(clusters, req.JobResources.Clusters) + assignedClusters := task.CopyParams(clusters, req.JobResources.Clusters, "inference") opt := &option.InferOption{ TaskName: taskName, @@ -158,7 +158,7 @@ func updateInferOption(cluster *strategy.AssignedCluster, opt *option.InferOptio opt.ImageId = cluster.ImageId opt.AlgorithmId = cluster.CodeId - opt.ModelId = cluster.ModelId + opt.ModelID = cluster.ModelId opt.ResourcesRequired = cluster.ResourcesRequired diff --git a/internal/logic/schedule/schedulecreatetasklogic.go b/internal/logic/schedule/schedulecreatetasklogic.go index 64b3383b..36d65d32 100644 --- a/internal/logic/schedule/schedulecreatetasklogic.go +++ b/internal/logic/schedule/schedulecreatetasklogic.go @@ -162,7 +162,7 @@ func (l *ScheduleCreateTaskLogic) ScheduleCreateTask(req *types.CreateTaskReq) ( assignedClusters := task.CopyParams([]*strategy.AssignedCluster{{ ClusterId: req.JobResources.Clusters[0].ClusterID, Replicas: 1, - }}, req.JobResources.Clusters) + }}, req.JobResources.Clusters, "") // filter data distribution clustersWithDataDistributes := generateFilteredDataDistributes(assignedClusters, req.DataDistributes) @@ -244,14 +244,14 @@ func (l *ScheduleCreateTaskLogic) getAssignedClustersByStrategy(resources *types if err != nil { return nil, err } - assignedClusters = task.CopyParams(clusters, resources.Clusters) + assignedClusters = task.CopyParams(clusters, resources.Clusters, "") case strategy.DATA_LOCALITY: strtg := strategy.NewDataLocality(TRAINNING_TASK_REPLICA, dataDistribute) clusters, err := strtg.Schedule() if err != nil { return nil, err } - assignedClusters = task.CopyParams(clusters, resources.Clusters) + assignedClusters = task.CopyParams(clusters, resources.Clusters, "") default: return nil, errors.New("no strategy has been chosen") } diff --git a/internal/scheduler/schedulers/option/inferOption.go b/internal/scheduler/schedulers/option/inferOption.go index 489a57c4..0e8ad69c 100644 --- a/internal/scheduler/schedulers/option/inferOption.go +++ b/internal/scheduler/schedulers/option/inferOption.go @@ -19,6 +19,7 @@ type InferOption struct { ResourceId string AlgorithmId string ImageId string + ModelID string Output string diff --git a/internal/scheduler/service/utils/task/taskParamChecker.go b/internal/scheduler/service/utils/task/taskParamChecker.go index 2fdc57fb..084e0118 100644 --- a/internal/scheduler/service/utils/task/taskParamChecker.go +++ b/internal/scheduler/service/utils/task/taskParamChecker.go @@ -34,7 +34,7 @@ func ValidateJobResources(resources types.JobResources, taskType string) error { return nil } -func CopyParams(clusters []*strategy.AssignedCluster, clusterInfos []*types.JobClusterInfo) []*strategy.AssignedCluster { +func CopyParams(clusters []*strategy.AssignedCluster, clusterInfos []*types.JobClusterInfo, taskType string) []*strategy.AssignedCluster { var result []*strategy.AssignedCluster for _, c := range clusters { @@ -69,5 +69,19 @@ func CopyParams(clusters []*strategy.AssignedCluster, clusterInfos []*types.JobC } } } + + if taskType == "inference" { + for _, c := range clusters { + for _, r := range result { + if c.ClusterId == r.ClusterId { + r.ModelId = c.ModelId + r.ModelName = c.ModelName + r.ImageId = c.ImageId + r.CodeId = c.CodeId + } + } + } + } + return result }