diff --git a/internal/logic/schedule/schedulecreatetasklogic.go b/internal/logic/schedule/schedulecreatetasklogic.go
index 2b52eac3..3d73f4d7 100644
--- a/internal/logic/schedule/schedulecreatetasklogic.go
+++ b/internal/logic/schedule/schedulecreatetasklogic.go
@@ -25,6 +25,11 @@ const (
 	QUERY_RESOURCE_RETRY = 3
 )
 
+type ClustersWithDataDistributes struct {
+	Clusters        []*strategy.AssignedCluster
+	DataDistributes *types.DataDistribute
+}
+
 type ScheduleCreateTaskLogic struct {
 	logx.Logger
 	ctx    context.Context
@@ -41,6 +46,98 @@ func NewScheduleCreateTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext)
 	}
 }
 
+func generateFilteredDataDistributes(clusters []*strategy.AssignedCluster, distribute types.DataDistribute) *ClustersWithDataDistributes {
+
+	var clusterIds []string
+	for _, c := range clusters {
+		clusterIds = append(clusterIds, c.ClusterId)
+	}
+
+	clustersWithDataDistributes := &ClustersWithDataDistributes{
+		Clusters: clusters,
+		DataDistributes: &types.DataDistribute{
+			Dataset: make([]*types.DatasetDistribute, 0),
+			Image:   make([]*types.ImageDistribute, 0),
+			Model:   make([]*types.ModelDistribute, 0),
+			Code:    make([]*types.CodeDistribute, 0),
+		},
+	}
+
+	for _, datasetDistribute := range distribute.Dataset {
+		dataset := &types.DatasetDistribute{}
+		dataset.DataName = datasetDistribute.DataName
+		dataset.PackageID = datasetDistribute.PackageID
+		clusterScheduledList := make([]*types.ClusterScheduled, 0)
+
+		if len(datasetDistribute.Clusters) != 0 {
+			for _, cluster := range datasetDistribute.Clusters {
+				if slices.Contains(clusterIds, cluster.ClusterID) {
+					clusterScheduledList = append(clusterScheduledList, cluster)
+				}
+			}
+		}
+
+		dataset.Clusters = clusterScheduledList
+		clustersWithDataDistributes.DataDistributes.Dataset = append(clustersWithDataDistributes.DataDistributes.Dataset, dataset)
+	}
+
+	for _, imageDistribute := range distribute.Image {
+		image := &types.ImageDistribute{}
+		image.DataName = imageDistribute.DataName
+		image.PackageID = imageDistribute.PackageID
+		clusterScheduledList := make([]*types.ClusterScheduled, 0)
+
+		if len(imageDistribute.Clusters) != 0 {
+			for _, cluster := range imageDistribute.Clusters {
+				if slices.Contains(clusterIds, cluster.ClusterID) {
+					clusterScheduledList = append(clusterScheduledList, cluster)
+				}
+			}
+		}
+
+		image.Clusters = clusterScheduledList
+		clustersWithDataDistributes.DataDistributes.Image = append(clustersWithDataDistributes.DataDistributes.Image, image)
+	}
+
+	for _, codeDistribute := range distribute.Code {
+		code := &types.CodeDistribute{}
+		code.DataName = codeDistribute.DataName
+		code.PackageID = codeDistribute.PackageID
+		clusterScheduledList := make([]*types.ClusterScheduled, 0)
+
+		if len(codeDistribute.Clusters) != 0 {
+			for _, cluster := range codeDistribute.Clusters {
+				if slices.Contains(clusterIds, cluster.ClusterID) {
+					clusterScheduledList = append(clusterScheduledList, cluster)
+				}
+			}
+		}
+
+		code.Clusters = clusterScheduledList
+		clustersWithDataDistributes.DataDistributes.Code = append(clustersWithDataDistributes.DataDistributes.Code, code)
+	}
+
+	for _, modelDistribute := range distribute.Model {
+		model := &types.ModelDistribute{}
+		model.DataName = modelDistribute.DataName
+		model.PackageID = modelDistribute.PackageID
+		clusterScheduledList := make([]*types.ClusterScheduled, 0)
+
+		if len(modelDistribute.Clusters) != 0 {
+			for _, cluster := range modelDistribute.Clusters {
+				if slices.Contains(clusterIds, cluster.ClusterID) {
+					clusterScheduledList = append(clusterScheduledList, cluster)
+				}
+			}
+		}
+
+		model.Clusters = clusterScheduledList
+		clustersWithDataDistributes.DataDistributes.Model = append(clustersWithDataDistributes.DataDistributes.Model, model)
+	}
+
+	return clustersWithDataDistributes
+}
+
 func (l *ScheduleCreateTaskLogic) ScheduleCreateTask(req *types.CreateTaskReq) (resp *types.CreateTaskResp, err error) {
 
 	resp = &types.CreateTaskResp{}
@@ -66,7 +163,10 @@ func (l *ScheduleCreateTaskLogic) ScheduleCreateTask(req *types.CreateTaskReq) (
 			ClusterId: req.JobResources.Clusters[0].ClusterID,
 		}}, req.JobResources.Clusters)
 
-		taskId, err := l.createTask(taskName, req.Description, req.JobResources.ScheduleStrategy, assignedClusters, req.Token)
+		// filter data distribution
+		clustersWithDataDistributes := generateFilteredDataDistributes(assignedClusters, req.DataDistributes)
+
+		taskId, err := l.createTask(taskName, req.Description, req.JobResources.ScheduleStrategy, clustersWithDataDistributes, req.Token)
 		if err != nil {
 			return nil, err
 		}
@@ -92,7 +192,11 @@ func (l *ScheduleCreateTaskLogic) ScheduleCreateTask(req *types.CreateTaskReq) (
 		if err != nil {
 			return nil, err
 		}
-		taskId, err := l.createTask(taskName, req.Description, req.JobResources.ScheduleStrategy, assignedClusters, req.Token)
+
+		// filter data distribution
+		clustersWithDataDistributes := generateFilteredDataDistributes(assignedClusters, req.DataDistributes)
+
+		taskId, err := l.createTask(taskName, req.Description, req.JobResources.ScheduleStrategy, clustersWithDataDistributes, req.Token)
 		if err != nil {
 			return nil, err
 		}
@@ -228,13 +332,13 @@ func copyParams(clusters []*strategy.AssignedCluster, clusterInfos []*types.JobC
 	return result
 }
 
-func (l *ScheduleCreateTaskLogic) createTask(taskName string, desc string, strategyName string, clusters []*strategy.AssignedCluster, token string) (int64, error) {
+func (l *ScheduleCreateTaskLogic) createTask(taskName string, desc string, strategyName string, clustersWithDataDistributes *ClustersWithDataDistributes, token string) (int64, error) {
 	var synergyStatus int64
-	if len(clusters) > 1 {
+	if len(clustersWithDataDistributes.Clusters) > 1 {
 		synergyStatus = 1
 	}
 
-	y, err := yaml.Marshal(clusters)
+	y, err := yaml.Marshal(clustersWithDataDistributes)
 	if err != nil {
 		fmt.Printf("Error while Marshaling. %v", err)
 	}
diff --git a/internal/logic/schedule/scheduleruntasklogic.go b/internal/logic/schedule/scheduleruntasklogic.go
index d04becc2..01fd7c3e 100644
--- a/internal/logic/schedule/scheduleruntasklogic.go
+++ b/internal/logic/schedule/scheduleruntasklogic.go
@@ -47,8 +47,8 @@ func (l *ScheduleRunTaskLogic) ScheduleRunTask(req *types.RunTaskReq) (resp *typ
 		return nil, errors.New("task has been cancelled ")
 	}
 
-	var clusters []*strategy.AssignedCluster
-	err = yaml.Unmarshal([]byte(task.YamlString), &clusters)
+	var clustersWithDataDistributes ClustersWithDataDistributes
+	err = yaml.Unmarshal([]byte(task.YamlString), &clustersWithDataDistributes)
 	if err != nil {
 		return nil, err
 	}
@@ -58,8 +58,9 @@ func (l *ScheduleRunTaskLogic) ScheduleRunTask(req *types.RunTaskReq) (resp *typ
 		TaskName:     task.Name,
 		StrategyName: "",
 	}
+
 	// update assignedClusters
-	err = updateClustersByScheduledDatas(task.Id, &clusters, req.ScheduledDatas)
+	assignedClusters, err := updateClustersByScheduledDatas(task.Id, &clustersWithDataDistributes, req.ScheduledDatas)
 	if err != nil {
 		return nil, err
 	}
@@ -69,7 +70,7 @@ func (l *ScheduleRunTaskLogic) ScheduleRunTask(req *types.RunTaskReq) (resp *typ
 		return nil, err
 	}
 
-	results, err := l.svcCtx.Scheduler.AssignAndSchedule(aiSchdl, executor.SUBMIT_MODE_STORAGE_SCHEDULE, clusters)
+	results, err := l.svcCtx.Scheduler.AssignAndSchedule(aiSchdl, executor.SUBMIT_MODE_STORAGE_SCHEDULE, assignedClusters)
 	if err != nil {
 		return nil, err
 	}
@@ -111,8 +112,10 @@ func (l *ScheduleRunTaskLogic) SaveResult(task *models.Task, results []*schedule
 
 }
 
-func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.AssignedCluster, scheduledDatas []*types.DataScheduleResults) error {
-	for _, cluster := range *assignedClusters {
+func updateClustersByScheduledDatas(taskId int64, clustersWithDataDistributes *ClustersWithDataDistributes, scheduledDatas []*types.DataScheduleResults) ([]*strategy.AssignedCluster, error) {
+	assignedClusters := make([]*strategy.AssignedCluster, 0)
+	// handle pass-in scheduledDatas
+	for _, cluster := range clustersWithDataDistributes.Clusters {
 		for _, data := range scheduledDatas {
 			switch data.DataType {
 			case "dataset":
@@ -131,7 +134,7 @@ func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.
 						}{}
 						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
 						if err != nil {
-							return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "dataset")
+							return nil, fmt.Errorf("pass-in jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "dataset")
 						}
 						cluster.DatasetId = jsonData.Id
 					}
@@ -153,7 +156,7 @@ func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.
 						}{}
 						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
 						if err != nil {
-							return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "image")
+							return nil, fmt.Errorf("pass-in jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "image")
 						}
 						cluster.ImageId = jsonData.Id
 					}
@@ -175,7 +178,7 @@ func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.
 						}{}
 						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
 						if err != nil {
-							return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "code")
+							return nil, fmt.Errorf("pass-in jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "code")
 						}
 						cluster.CodeId = jsonData.Id
 					}
@@ -197,7 +200,7 @@ func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.
 						}{}
 						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
 						if err != nil {
-							return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "model")
+							return nil, fmt.Errorf("pass-in jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "model")
 						}
 						cluster.ModelId = jsonData.Id
 					}
@@ -205,21 +208,110 @@ func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.
 				}
 			}
 		}
+		assignedClusters = append(assignedClusters, cluster)
+	}
+
+	// handle db yaml clustersWithDataDistributes
+	for _, cluster := range assignedClusters {
+		if cluster.DatasetId == "" {
+			for _, distribute := range clustersWithDataDistributes.DataDistributes.Dataset {
+				for _, c := range distribute.Clusters {
+					if cluster.ClusterId == c.ClusterID {
+						if c.JsonData == "" {
+							continue
+						}
+						jsonData := struct {
+							Name string `json:"name"`
+							Id   string `json:"id"`
+						}{}
+						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
+						if err != nil {
+							return nil, fmt.Errorf("db yaml jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "dataset")
+						}
+						cluster.DatasetId = jsonData.Id
+					}
+				}
+			}
+		}
+
+		if cluster.ImageId == "" {
+			for _, distribute := range clustersWithDataDistributes.DataDistributes.Image {
+				for _, c := range distribute.Clusters {
+					if cluster.ClusterId == c.ClusterID {
+						if c.JsonData == "" {
+							continue
+						}
+						jsonData := struct {
+							Name string `json:"name"`
+							Id   string `json:"id"`
+						}{}
+						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
+						if err != nil {
+							return nil, fmt.Errorf("db yaml jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "image")
+						}
+						cluster.ImageId = jsonData.Id
+					}
+				}
+			}
+		}
+
+		if cluster.CodeId == "" {
+			for _, distribute := range clustersWithDataDistributes.DataDistributes.Code {
+				for _, c := range distribute.Clusters {
+					if cluster.ClusterId == c.ClusterID {
+						if c.JsonData == "" {
+							continue
+						}
+						jsonData := struct {
+							Name string `json:"name"`
+							Id   string `json:"id"`
+						}{}
+						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
+						if err != nil {
+							return nil, fmt.Errorf("db yaml jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "code")
+						}
+						cluster.CodeId = jsonData.Id
+					}
+				}
+			}
+		}
+
+		if cluster.ModelId == "" {
+			for _, distribute := range clustersWithDataDistributes.DataDistributes.Model {
+				for _, c := range distribute.Clusters {
+					if cluster.ClusterId == c.ClusterID {
+						if c.JsonData == "" {
+							continue
+						}
+						jsonData := struct {
+							Name string `json:"name"`
+							Id   string `json:"id"`
+						}{}
+						err := json.Unmarshal([]byte(c.JsonData), &jsonData)
+						if err != nil {
+							return nil, fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "model")
+						}
+						cluster.ModelId = jsonData.Id
+					}
+				}
+			}
+		}
 	}
 
-	for _, cluster := range *assignedClusters {
+	// check empty data
+	for _, cluster := range assignedClusters {
 		if cluster.DatasetId == "" {
-			return fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "DatasetId")
+			return nil, fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "DatasetId")
 		}
 
 		if cluster.ImageId == "" {
-			return fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "ImageId")
+			return nil, fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "ImageId")
 		}
 
 		if cluster.CodeId == "" {
-			return fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "CodeId")
+			return nil, fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "CodeId")
 		}
 	}
 
-	return nil
+	return assignedClusters, nil
 }
diff --git a/internal/storeLink/openi.go b/internal/storeLink/openi.go
index d5c5c011..f2805606 100644
--- a/internal/storeLink/openi.go
+++ b/internal/storeLink/openi.go
@@ -460,6 +460,8 @@ func (o OpenI) GetTrainingTask(ctx context.Context, taskId string) (*collector.T
 		resp.Status = constants.Stopped
 	case "PENDING":
 		resp.Status = constants.Pending
+	case "WAITING":
+		resp.Status = constants.Waiting
 	default:
 		resp.Status = "undefined"
 	}
diff --git a/pkg/constants/task.go b/pkg/constants/task.go
index d21f5b9d..e9fb5841 100644
--- a/pkg/constants/task.go
+++ b/pkg/constants/task.go
@@ -30,4 +30,5 @@ const (
 	Stopped   = "Stopped"
 	Deploying = "Deploying"
 	Cancelled = "Cancelled"
+	Waiting   = "Waiting"
 )