Compare commits

...

1 Commits

Author SHA1 Message Date
  ychao_1983 65f43fa820 提交代码 3 years ago
4 changed files with 59 additions and 30 deletions
Split View
  1. +19
    -18
      models/cloudbrain.go
  2. +25
    -12
      modules/cloudbrain/cloudbrain.go
  3. +1
    -0
      modules/modelarts/modelarts.go
  4. +14
    -0
      routers/repo/cloudbrain.go

+ 19
- 18
models/cloudbrain.go View File

@@ -168,24 +168,25 @@ type Cloudbrain struct {
ImageID string //grampus image_id
AiCenter string //grampus ai center: center_id+center_name

TrainUrl string //输出模型的obs路径
BranchName string //分支名称
Parameters string //传给modelarts的param参数
BootFile string //启动文件
DataUrl string //数据集的obs路径
LogUrl string //日志输出的obs路径
PreVersionId int64 //父版本的版本id
FlavorCode string //modelarts上的规格id
Description string `xorm:"varchar(256)"` //描述
WorkServerNumber int //节点数
FlavorName string //规格名称
EngineName string //引擎名称
TotalVersionCount int //任务的所有版本数量,包括删除的
LabelName string //标签名称
ModelName string //模型名称
ModelVersion string //模型版本
CkptName string //权重文件名称
ResultUrl string //推理结果的obs路径
TrainUrl string //输出模型的obs路径
BranchName string //分支名称
Parameters string //传给modelarts的param参数
BootFile string //启动文件
DataUrl string //数据集的obs路径
LogUrl string //日志输出的obs路径
PreVersionId int64 //父版本的版本id
FlavorCode string //modelarts上的规格id
Description string `xorm:"varchar(256)"` //描述
WorkServerNumber int //节点数
FlavorName string //规格名称
EngineName string //引擎名称
TotalVersionCount int //任务的所有版本数量,包括删除的
LabelName string //标签名称
ModelName string //模型名称
ModelVersion string //模型版本
CkptName string //权重文件名称
PreTrainingModelUrl string //预训练模型地址
ResultUrl string //推理结果的obs路径

User *User `xorm:"-"`
Repo *Repository `xorm:"-"`


+ 25
- 12
modules/cloudbrain/cloudbrain.go View File

@@ -20,18 +20,19 @@ import (
const (
//Command = `pip3 install jupyterlab==2.2.5 -i https://pypi.tuna.tsinghua.edu.cn/simple;service ssh stop;jupyter lab --no-browser --ip=0.0.0.0 --allow-root --notebook-dir="/code" --port=80 --LabApp.token="" --LabApp.allow_origin="self https://cloudbrain.pcl.ac.cn"`
//CommandBenchmark = `echo "start benchmark";python /code/test.py;echo "end benchmark"`
CommandBenchmark = `echo "start benchmark";cd /benchmark && bash run_bk.sh;echo "end benchmark"`
CodeMountPath = "/code"
DataSetMountPath = "/dataset"
ModelMountPath = "/model"
LogFile = "log.txt"
BenchMarkMountPath = "/benchmark"
BenchMarkResourceID = 1
Snn4imagenetMountPath = "/snn4imagenet"
BrainScoreMountPath = "/brainscore"
TaskInfoName = "/taskInfo"
Snn4imagenetCommand = `/opt/conda/bin/python /snn4imagenet/testSNN_script.py --modelname '%s' --modelpath '/dataset' --modeldescription '%s'`
BrainScoreCommand = `bash /brainscore/brainscore_test_par4shSrcipt.sh -b '%s' -n '%s' -p '/dataset' -d '%s'`
CommandBenchmark = `echo "start benchmark";cd /benchmark && bash run_bk.sh;echo "end benchmark"`
CodeMountPath = "/code"
DataSetMountPath = "/dataset"
ModelMountPath = "/model"
PretrainModelMountPath = "/pretrainmodel"
LogFile = "log.txt"
BenchMarkMountPath = "/benchmark"
BenchMarkResourceID = 1
Snn4imagenetMountPath = "/snn4imagenet"
BrainScoreMountPath = "/brainscore"
TaskInfoName = "/taskInfo"
Snn4imagenetCommand = `/opt/conda/bin/python /snn4imagenet/testSNN_script.py --modelname '%s' --modelpath '/dataset' --modeldescription '%s'`
BrainScoreCommand = `bash /brainscore/brainscore_test_par4shSrcipt.sh -b '%s' -n '%s' -p '/dataset' -d '%s'`

SubTaskName = "task1"

@@ -77,6 +78,8 @@ type GenerateCloudBrainTaskReq struct {
ModelVersion string
CkptName string
LabelName string
PreTrainModelPath string
PreTrainingModelUrl string
Spec *models.Specification
}

@@ -276,6 +279,16 @@ func GenerateTask(req GenerateCloudBrainTaskReq) error {
},
}

if req.PreTrainingModelUrl != "" { //预训练
volumes = append(volumes, models.Volume{
HostPath: models.StHostPath{
Path: req.PreTrainModelPath,
MountPath: PretrainModelMountPath,
ReadOnly: true,
},
})
}

if len(req.DatasetInfos) == 1 {
volumes = append(volumes, models.Volume{
HostPath: models.StHostPath{


+ 1
- 0
modules/modelarts/modelarts.go View File

@@ -54,6 +54,7 @@ const (
MultiDataUrl = "multi_data_url"
ResultUrl = "result_url"
CkptUrl = "ckpt_url"
PreTrainModelUrl = "pretrain_model_url"
DeviceTarget = "device_target"
Ascend = "Ascend"
PerPage = 10


+ 14
- 0
routers/repo/cloudbrain.go View File

@@ -327,6 +327,17 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) {
ResultPath: storage.GetMinioPath(jobName, cloudbrain.ResultPath+"/"),
Spec: spec,
}
if form.ModelName != "" { //使用预训练模型训练
req.ModelName = form.ModelName
req.LabelName = form.LabelName
req.CkptName = form.CkptName
req.ModelVersion = form.ModelVersion
req.PreTrainModelPath = setting.Attachment.Minio.RealPath + form.PreTrainingModelUrl
req.PreTrainingModelUrl = form.PreTrainingModelUrl

}


err = cloudbrain.GenerateTask(req)
if err != nil {
@@ -2626,6 +2637,9 @@ func getTrainJobCommand(form auth.CreateCloudBrainForm) (string, error) {
param += " --" + parameter.Label + "=" + parameter.Value
}
}
if form.CkptName != "" {
param += " --pretrainmodelname" + "=" + form.CkptName
}

command += "python /code/" + bootFile + param + " | tee " + cloudbrain.ModelMountPath + "/" + form.DisplayJobName + "-" + cloudbrain.LogFile



Loading…
Cancel
Save