Browse Source

提交后端功能代码。

Signed-off-by: zouap <zouap@pcl.ac.cn>
tags/v1.22.10.1^2
zouap 3 years ago
parent
commit
0f4a71ea9a
4 changed files with 163 additions and 16 deletions
  1. +1
    -1
      modules/aisafety/resty.go
  2. +16
    -0
      modules/setting/setting.go
  3. +140
    -11
      routers/repo/aisafety.go
  4. +6
    -4
      routers/routes/routes.go

+ 1
- 1
modules/aisafety/resty.go View File

@@ -227,7 +227,7 @@ func GetTaskStatus(jobID string) (map[string]interface{}, error) {

if err != nil {
log.Info("error =" + err.Error())
return nil, fmt.Errorf("resty GetJob: %v", err)
return nil, fmt.Errorf("Get task status error: %v", err)
} else {
reMap := make(map[string]interface{})
err = json.Unmarshal(res.Body(), &reMap)


+ 16
- 0
modules/setting/setting.go View File

@@ -707,6 +707,13 @@ var (
NPU_MINDSPORE_IMAGE_ID int
NPU_TENSORFLOW_IMAGE_ID int
}{}

ModelSafetyTest = struct {
BaseDataSetName string
BaseDataSetUUID string
CombatDataSetName string
CombatDataSetUUID string
}{}
)

// DateLang transforms standard language locale name to corresponding value in datetime plugin.
@@ -1527,6 +1534,15 @@ func NewContext() {
getGrampusConfig()
getModelartsCDConfig()
getModelConvertConfig()
getModelSafetyConfig()
}

func getModelSafetyConfig() {
sec := Cfg.Section("model_safety_test")
ModelSafetyTest.BaseDataSetName = sec.Key("BaseDataSetName").MustString("ImageNet1000_100基础数据集;CIFAR10_1000基础数据集")
ModelSafetyTest.BaseDataSetUUID = sec.Key("BaseDataSetUUID").MustString("0fa81800-e95e-42f4-ab40-2c3ca83f2344;6eaab665-1c68-45fc-ad05-c070f2db092e")
ModelSafetyTest.CombatDataSetName = sec.Key("CombatDataSetName").MustString("ImageNet1000_100_FGSM;CIFAR10_1000_FGSM.zip")
ModelSafetyTest.CombatDataSetUUID = sec.Key("CombatDataSetUUID").MustString("9ba30d3f-83e1-4f9f-849d-6f93217e2ca3;23825796-e4f3-4cf8-b697-9963048cef42")
}

func getModelConvertConfig() {


+ 140
- 11
routers/repo/aisafety.go View File

@@ -1,8 +1,10 @@
package repo

import (
"bufio"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
"os"
@@ -22,6 +24,10 @@ import (
uuid "github.com/satori/go.uuid"
)

const (
tplModelSafetyTestCreate = "repo/modelsafety/new"
)

func CloudBrainAiSafetyCreateTest(ctx *context.Context) {
log.Info("start to create CloudBrainAiSafetyCreate")
uuid := uuid.NewV4()
@@ -57,34 +63,148 @@ func CloudBrainAiSafetyCreateTest(ctx *context.Context) {

func GetAiSafetyTask(ctx *context.Context) {
var ID = ctx.Params(":jobid")
task, err := models.GetCloudbrainByJobIDWithDeleted(ID)
getAiSafetyTaskStatusFromCloudbrain(ID)
}

func getAiSafetyTaskStatusFromCloudbrain(ID string) {
job, err := models.GetCloudbrainByJobIDWithDeleted(ID)
if err != nil {
log.Error("GetCloudbrainByJobID failed:" + err.Error())
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
return
}
if task.Type == models.TypeCloudBrainTwo {
if job.Type == models.TypeCloudBrainTwo {

} else if job.Type == models.TypeCloudBrainOne {
if isTaskNotFinished(job.Status) {
log.Info("The task not finished,name=" + job.DisplayJobName)
jobResult, err := cloudbrain.GetJob(job.JobID)

result, err := models.ConvertToJobResultPayload(jobResult.Payload)
if err != nil {
log.Error("ConvertToJobResultPayload failed:", err)
return
}
job.Status = result.JobStatus.State
if result.JobStatus.State != string(models.JobWaiting) && result.JobStatus.State != string(models.JobFailed) {
taskRoles := result.TaskRoles
taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
job.Status = taskRes.TaskStatuses[0].State
}

if result.JobStatus.State != string(models.JobSucceeded) {
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
} else {
//
job.Status = string(models.ModelSafetyTesting)
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
//send msg to beihang
sendGpuInferenceResultToTest(job)
}

} else if task.Type == models.TypeCloudBrainOne {
} else {
if job.Status == string(models.ModelSafetyTesting) {
//
result, err := aisafety.GetTaskStatus(job.PreVersionName)
if err == nil {
if result["code"] != nil {

}
}
}
}

}
}

func isTaskFinished(status string) bool {
func sendGpuInferenceResultToTest(job *models.Cloudbrain) {
datasetname := job.DatasetName
datasetnames := strings.Split(datasetname, ";")
indicator := job.LabelName

req := aisafety.TaskReq{
UnionId: job.JobID,
EvalName: job.DisplayJobName,
EvalContent: job.Description,
TLPath: "test",
Indicators: strings.Split(indicator, ";"),
CDName: datasetnames[1],
BDName: datasetnames[0],
}

resultDir := "/model"
prefix := "/" + setting.CBCodePathPrefix + job.JobName + resultDir
files, err := storage.GetOneLevelAllObjectUnderDirMinio(setting.Attachment.Minio.Bucket, prefix, "")
if err != nil {
log.Error("query cloudbrain one model failed: %v", err)
return
}
jsonContent := ""
for _, file := range files {
if strings.HasSuffix(file.FileName, "result.json") {
path := storage.GetMinioPath(job.JobName+resultDir+"/", file.FileName)
log.Info("path=" + path)
reader, err := os.Open(path)
defer reader.Close()
if err == nil {
r := bufio.NewReader(reader)
for {
line, error := r.ReadString('\n')
if error == io.EOF {
log.Info("read file completed.")
break
}
if error != nil {
log.Info("read file error." + error.Error())
break
}
jsonContent += line
}
}
break
}
}
if jsonContent != "" {
serialNo, err := aisafety.CreateSafetyTask(req, jsonContent)
if err == nil {
//update serial no to db
job.PreVersionName = serialNo
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
}
} else {
log.Info("The json is null. so set it failed.")
//update task failed.
job.Status = string(models.JobFailed)
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
}
}

func isTaskNotFinished(status string) bool {
if status == string(models.ModelArtsTrainJobRunning) || status == string(models.ModelArtsTrainJobWaiting) {
return false
return true
}
if status == string(models.JobWaiting) || status == string(models.JobRunning) {
return false
return true
}

if status == string(models.ModelArtsTrainJobUnknown) || status == string(models.ModelArtsTrainJobInit) {
return false
return true
}
if status == string(models.ModelArtsTrainJobImageCreating) || status == string(models.ModelArtsTrainJobSubmitTrying) {
return false
return true
}
return true
return false
}

func StopAiSafetyTask(ctx *context.Context) {
@@ -95,7 +215,16 @@ func DelAiSafetyTask(ctx *context.Context) {

}

func CloudBrainAiSafetyCreate(ctx *context.Context) {
func AiSafetyCreateForGet(ctx *context.Context) {
ctx.Data["PageIsCloudBrain"] = true
ctx.Data["BaseDataSetName"] = setting.ModelSafetyTest.BaseDataSetName
ctx.Data["BaseDataSetUUID"] = setting.ModelSafetyTest.BaseDataSetUUID
ctx.Data["CombatDataSetName"] = setting.ModelSafetyTest.CombatDataSetName
ctx.Data["CombatDataSetUUID"] = setting.ModelSafetyTest.CombatDataSetUUID
ctx.HTML(200, tplModelSafetyTestCreate)
}

func AiSafetyCreateForPost(ctx *context.Context) {
ctx.Data["PageIsCloudBrain"] = true
displayJobName := ctx.Query("DisplayJobName")
jobName := util.ConvertDisplayJobNameToJobName(displayJobName)


+ 6
- 4
routers/routes/routes.go View File

@@ -6,15 +6,16 @@ package routes

import (
"bytes"
"code.gitea.io/gitea/routers/reward/point"
"code.gitea.io/gitea/routers/task"
"code.gitea.io/gitea/services/reward"
"encoding/gob"
"net/http"
"path"
"text/template"
"time"

"code.gitea.io/gitea/routers/reward/point"
"code.gitea.io/gitea/routers/task"
"code.gitea.io/gitea/services/reward"

"code.gitea.io/gitea/modules/slideimage"

"code.gitea.io/gitea/routers/image"
@@ -1231,7 +1232,8 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("", reqRepoCloudBrainWriter, repo.GetAiSafetyTask)
m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.StopAiSafetyTask)
m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.DelAiSafetyTask)
m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.CloudBrainAiSafetyCreate)
m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.AiSafetyCreateForGet)
m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.AiSafetyCreateForPost)
})
}, context.RepoRef())



Loading…
Cancel
Save