diff --git a/services/cloudbrain/cloudbrainTask/inference.go b/services/cloudbrain/cloudbrainTask/inference.go index ba5ba7cf1..5481a8b4c 100644 --- a/services/cloudbrain/cloudbrainTask/inference.go +++ b/services/cloudbrain/cloudbrainTask/inference.go @@ -96,7 +96,7 @@ func CloudBrainInferenceJobCreate(ctx *context.Context, option api.CreateTrainJo return } - count, err := models.GetCloudbrainCountByUserID(ctx.User.ID, jobType) + count, err := GetNotFinalStatusTaskCount(ctx.User.ID, models.TypeCloudBrainOne, jobType) if err != nil { log.Error("GetCloudbrainCountByUserID failed:%v", err, ctx.Data["MsgID"]) ctx.JSON(http.StatusOK, models.BaseErrorMessage("system error")) @@ -226,7 +226,7 @@ func ModelArtsInferenceJobCreate(ctx *context.Context, option api.CreateTrainJob } defer lock.UnLock() - count, err := models.GetCloudbrainInferenceJobCountByUserID(ctx.User.ID) + count, err := GetNotFinalStatusTaskCount(ctx.User.ID, models.TypeCloudBrainTwo, string(models.JobTypeInference)) if err != nil { log.Error("GetCloudbrainInferenceJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) diff --git a/services/cloudbrain/cloudbrainTask/train.go b/services/cloudbrain/cloudbrainTask/train.go index 28d5038bf..5af9a1181 100644 --- a/services/cloudbrain/cloudbrainTask/train.go +++ b/services/cloudbrain/cloudbrainTask/train.go @@ -174,8 +174,12 @@ func checkParameters(ctx *context.Context, option api.CreateTrainJobOption, lock return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_bootfile_err")) } + computeResource := models.GPUResource + if option.Type == 3 { + computeResource = models.NPUResource + } //check count limit - count, err := models.GetGrampusCountByUserID(ctx.User.ID, string(models.JobTypeTrain), models.GPUResource) + count, err := GetNotFinalStatusTaskCount(ctx.User.ID, models.TypeC2Net, string(models.JobTypeTrain), computeResource) if err != nil { log.Error("GetGrampusCountByUserID failed:%v", err, ctx.Data["MsgID"]) return nil, nil, "", fmt.Errorf("system error") @@ -207,13 +211,13 @@ func checkParameters(ctx *context.Context, option api.CreateTrainJobOption, lock } //check specification - computeResource := models.GPU + computeType := models.GPU if option.Type == 3 { - computeResource = models.NPU + computeType = models.NPU } spec, err := resource.GetAndCheckSpec(ctx.User.ID, option.SpecId, models.FindSpecsOptions{ JobType: models.JobTypeTrain, - ComputeResource: computeResource, + ComputeResource: computeType, Cluster: models.C2NetCluster, }) if err != nil || spec == nil { @@ -226,7 +230,7 @@ func checkParameters(ctx *context.Context, option api.CreateTrainJobOption, lock } //check dataset - datasetInfos, datasetNames, err := models.GetDatasetInfo(option.Attachment, computeResource) + datasetInfos, datasetNames, err := models.GetDatasetInfo(option.Attachment, computeType) if err != nil { log.Error("GetDatasetInfo failed: %v", err, ctx.Data["MsgID"]) return nil, nil, "", fmt.Errorf(ctx.Tr("cloudbrain.error.dataset_select"))