Browse Source

整理代码

tags/v1.22.1.2
liuzx 3 years ago
parent
commit
fe08d55c4f
3 changed files with 15 additions and 33 deletions
  1. +4
    -7
      models/cloudbrain.go
  2. +7
    -7
      modules/modelarts/resty.go
  3. +4
    -19
      routers/repo/modelarts.go

+ 4
- 7
models/cloudbrain.go View File

@@ -665,13 +665,10 @@ type InfConfig struct {
Parameter []Parameter `json:"parameter"`
DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL
EngineID int64 `json:"engine_id"`
// TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL
LogUrl string `json:"log_url"`
//UserImageUrl string `json:"user_image_url"`
//UserCommand string `json:"user_command"`
CreateVersion bool `json:"create_version"`
Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"`
LogUrl string `json:"log_url"`
CreateVersion bool `json:"create_version"`
Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"`
}

type CreateTrainJobVersionParams struct {


+ 7
- 7
modules/modelarts/resty.go View File

@@ -891,7 +891,7 @@ sendjob:
Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob)

if err != nil {
return nil, fmt.Errorf("resty create train-job: %s", err)
return nil, fmt.Errorf("resty create inference-job: %s", err)
}

req, _ := json.Marshal(createJobParams)
@@ -909,23 +909,23 @@ sendjob:
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
}
log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
log.Error("createInferenceJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
BootFileErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.BootFileUrl + "'."
DataSetErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.DataUrl + "'."
if temp.ErrorMsg == BootFileErrorMsg {
log.Error("启动文件错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
log.Error("启动文件错误!createInferenceJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("启动文件错误!")
}
if temp.ErrorMsg == DataSetErrorMsg {
log.Error("数据集错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
log.Error("数据集错误!createInferenceJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("数据集错误!")
}
return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("createInferenceJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
}

if !result.IsSuccess {
log.Error("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
log.Error("createInferenceJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf("createInferenceJob failed(%s): %s", result.ErrorCode, result.ErrorMsg)
}

return &result, nil


+ 4
- 19
routers/repo/modelarts.go View File

@@ -797,24 +797,11 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
return
}

// attach, err := models.GetAttachmentByUUID(uuid)
// if err != nil {
// log.Error("GetAttachmentByUUID(%s) failed:%v", uuid, err.Error())
// return
// }

//todo: del the codeLocalPath
_, err = ioutil.ReadDir(codeLocalPath)
if err == nil {
os.RemoveAll(codeLocalPath)
}
// } else {
// log.Error("创建任务失败,原代码还未删除,请重试!: %s (%v)", repo.FullName(), err)
// versionErrorDataPrepare(ctx, form)
// ctx.RenderWithErr("创建任务失败,原代码还未删除,请重试!", tplModelArtsTrainJobVersionNew, &form)
// return
// }
// os.RemoveAll(codeLocalPath)

gitRepo, _ := git.OpenRepository(repo.RepoPath())
commitID, _ := gitRepo.GetBranchCommitID(branch_name)
@@ -1020,12 +1007,6 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
return
}

// attach, err := models.GetAttachmentByUUID(uuid)
// if err != nil {
// log.Error("GetAttachmentByUUID(%s) failed:%v", uuid, err.Error())
// return
// }

//todo: del the codeLocalPath
_, err = ioutil.ReadDir(codeLocalPath)
if err == nil {
@@ -1297,6 +1278,10 @@ func paramCheckCreateInferenceJob(form auth.CreateModelArtsInferenceJobForm) err
log.Error("the CkptName(%d) must not be nil", form.CkptName)
return errors.New("权重文件不能为空")
}
if form.BranchName == "" {
log.Error("the Branch(%d) must not be nil", form.BranchName)
return errors.New("分支名不能为空")
}

return nil
}


Loading…
Cancel
Save