diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 810e68d30..9b30c4200 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -531,11 +531,12 @@ type ResourceSpecs struct { } type ResourceSpec struct { - Id int `json:"id"` - CpuNum int `json:"cpu"` - GpuNum int `json:"gpu"` - MemMiB int `json:"memMiB"` - ShareMemMiB int `json:"shareMemMiB"` + Id int `json:"id"` + CpuNum int `json:"cpu"` + GpuNum int `json:"gpu"` + MemMiB int `json:"memMiB"` + ShareMemMiB int `json:"shareMemMiB"` + UnitPrice int64 `json:"unitPrice"` } type FlavorInfos struct { @@ -543,9 +544,10 @@ type FlavorInfos struct { } type FlavorInfo struct { - Id int `json:"id"` - Value string `json:"value"` - Desc string `json:"desc"` + Id int `json:"id"` + Value string `json:"value"` + Desc string `json:"desc"` + UnitPrice int64 `json:"unitPrice"` } type ImageInfosModelArts struct { @@ -1692,3 +1694,12 @@ func CloudbrainAll(opts *CloudbrainsOptions) ([]*CloudbrainInfo, int64, error) { return cloudbrains, count, nil } + +func GetStartedCloudbrainTaskByUpdatedUnix(startTime, endTime time.Time) ([]Cloudbrain, error) { + r := make([]Cloudbrain, 0) + err := x.Where("updated_unix >= ? and updated_unix <= ? and start_time > 0", startTime.Unix(), endTime.Unix()).Find(&r) + if err != nil { + return nil, err + } + return r, nil +} diff --git a/models/models.go b/models/models.go index 59e7a3a48..c6c0d6610 100755 --- a/models/models.go +++ b/models/models.go @@ -148,7 +148,7 @@ func init() { new(TaskAccomplishLog), new(RewardOperateRecord), new(LimitConfig), - new(PeriodicTask), + new(RewardPeriodicTask), new(PointAccountLog), new(PointAccount), ) diff --git a/models/point_periodic_task.go b/models/point_periodic_task.go deleted file mode 100644 index 0d4297f2f..000000000 --- a/models/point_periodic_task.go +++ /dev/null @@ -1,28 +0,0 @@ -package models - -import "code.gitea.io/gitea/modules/timeutil" - -type PeriodicTaskStatus int - -// Possible PeriodicTaskStatus types. -const ( - PeriodicTaskStatusRunning PointAccountStatus = iota + 1 // 1 - PeriodicTaskStatusSuccess // 2 - PeriodicTaskStatusFailed // 3 -) - -type PeriodicTask struct { - ID int64 `xorm:"pk autoincr"` - Type string `xorm:"NOT NULL"` - OperateRecordId int64 `xorm:"INDEX NOT NULL"` - IntervalSecond int64 `xorm:"NOT NULL"` - PointsAmount int64 `xorm:"NOT NULL"` - NextExecuteTime timeutil.TimeStamp - SuccessCount int `xorm:"NOT NULL default 0"` - FailedCount int `xorm:"NOT NULL default 0"` - Status string `xorm:"NOT NULL"` - ExitCode string - CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` - FinishedUnix timeutil.TimeStamp `xorm:"INDEX"` - UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` -} diff --git a/models/reward_operate_record.go b/models/reward_operate_record.go index 4c31df03c..d3b2e0a10 100644 --- a/models/reward_operate_record.go +++ b/models/reward_operate_record.go @@ -6,12 +6,27 @@ import ( "xorm.io/builder" ) +type SourceType string + const ( - SourceTypeAccomplishTask string = "ACCOMPLISH_TASK" - SourceTypeAdminOperate = "ADMIN_OPERATE" - SourceTypeRunCloudbrainTask = "RUN_CLOUBRAIN_TASK" + SourceTypeAccomplishTask SourceType = "ACCOMPLISH_TASK" + SourceTypeAdminOperate SourceType = "ADMIN_OPERATE" + SourceTypeRunCloudbrainTask SourceType = "RUN_CLOUDBRAIN_TASK" ) +func (r SourceType) Name() string { + switch r { + case SourceTypeAccomplishTask: + return "ACCOMPLISH_TASK" + case SourceTypeAdminOperate: + return "ADMIN_OPERATE" + case SourceTypeRunCloudbrainTask: + return "RUN_CLOUDBRAIN_TASK" + default: + return "" + } +} + type RewardType string const ( @@ -66,6 +81,17 @@ func (r RewardOperateType) Show() string { } } +func GetRewardOperateTypeInstance(s string) RewardOperateType { + switch s { + case OperateTypeIncrease.Name(): + return OperateTypeIncrease + case OperateTypeDecrease.Name(): + return OperateTypeDecrease + default: + return "" + } +} + const ( OperateTypeIncrease RewardOperateType = "INCREASE" OperateTypeDecrease RewardOperateType = "DECREASE" @@ -80,20 +106,19 @@ const ( const Semicolon = ";" type RewardOperateRecord struct { - ID int64 `xorm:"pk autoincr"` - RecordId string `xorm:"INDEX NOT NULL"` - UserId int64 `xorm:"INDEX NOT NULL"` - Amount int64 `xorm:"NOT NULL"` - RewardType string `xorm:"NOT NULL"` - SourceType string `xorm:"NOT NULL"` - SourceId string `xorm:"INDEX NOT NULL"` - RequestId string `xorm:"INDEX NOT NULL"` - OperateType string `xorm:"NOT NULL"` - CycleIntervalSeconds int64 `xorm:"NOT NULL default 0"` - Status string `xorm:"NOT NULL"` - Remark string - CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` - UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` + ID int64 `xorm:"pk autoincr"` + RecordId string `xorm:"INDEX NOT NULL"` + UserId int64 `xorm:"INDEX NOT NULL"` + Amount int64 `xorm:"NOT NULL"` + RewardType string `xorm:"NOT NULL"` + SourceType string `xorm:"NOT NULL"` + SourceId string `xorm:"INDEX NOT NULL"` + RequestId string `xorm:"INDEX NOT NULL"` + OperateType string `xorm:"NOT NULL"` + Status string `xorm:"NOT NULL"` + Remark string + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` } func getPointOperateRecord(tl *RewardOperateRecord) (*RewardOperateRecord, error) { @@ -106,10 +131,18 @@ func getPointOperateRecord(tl *RewardOperateRecord) (*RewardOperateRecord, error return tl, nil } -func GetPointOperateRecordBySourceTypeAndRequestId(sourceType, requestId string) (*RewardOperateRecord, error) { +func GetPointOperateRecordBySourceTypeAndRequestId(sourceType, requestId, operateType string) (*RewardOperateRecord, error) { + t := &RewardOperateRecord{ + SourceType: sourceType, + RequestId: requestId, + OperateType: operateType, + } + return getPointOperateRecord(t) +} + +func GetPointOperateRecordByRecordId(recordId string) (*RewardOperateRecord, error) { t := &RewardOperateRecord{ - SourceType: sourceType, - RequestId: requestId, + RecordId: recordId, } return getPointOperateRecord(t) } @@ -140,14 +173,13 @@ func SumRewardAmountInTaskPeriod(rewardType string, sourceType string, userId in } type RewardOperateContext struct { - SourceType string - SourceId string - Remark string - Reward Reward - TargetUserId int64 - RequestId string - OperateType RewardOperateType - CycleIntervalSeconds int64 + SourceType SourceType + SourceId string + Remark string + Reward Reward + TargetUserId int64 + RequestId string + OperateType RewardOperateType } type Reward struct { diff --git a/models/reward_periodic_task.go b/models/reward_periodic_task.go new file mode 100644 index 000000000..e6ebd17c2 --- /dev/null +++ b/models/reward_periodic_task.go @@ -0,0 +1,114 @@ +package models + +import ( + "code.gitea.io/gitea/modules/timeutil" + "time" +) + +type PeriodicTaskStatus int + +const ( + PeriodicTaskStatusRunning = iota + 1 // 1 + PeriodicTaskStatusFinished // 2 +) + +type PeriodType string + +const ( + PeriodType30MinutesFree1HourCost PeriodType = "30MF1HC" +) + +func (r PeriodType) Name() string { + switch r { + case PeriodType30MinutesFree1HourCost: + return "30MF1HC" + default: + return "" + } +} + +type RewardPeriodicTask struct { + ID int64 `xorm:"pk autoincr"` + OperateRecordId string `xorm:"INDEX NOT NULL"` + DelaySeconds int64 + IntervalSeconds int64 + Amount int64 `xorm:"NOT NULL"` + NextExecuteTime timeutil.TimeStamp `xorm:"INDEX NOT NULL"` + SuccessCount int `xorm:"NOT NULL default 0"` + Status int `xorm:"NOT NULL"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + FinishedUnix timeutil.TimeStamp `xorm:"INDEX"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +type StartPeriodicTaskOpts struct { + SourceType SourceType + SourceId string + Remark string + TargetUserId int64 + RequestId string + OperateType RewardOperateType + Delay time.Duration + Interval time.Duration + UnitAmount int64 + RewardType RewardType + StartTime time.Time +} + +func InsertPeriodicTask(tl *RewardPeriodicTask) (int64, error) { + return x.Insert(tl) +} + +func GetRunningRewardTask(now time.Time) ([]RewardPeriodicTask, error) { + r := make([]RewardPeriodicTask, 0) + err := x.Where("next_execute_time <= ? and status = ?", now.Unix(), PeriodicTaskStatusRunning).Find(&r) + if err != nil { + return nil, err + } + return r, err +} + +func IncrRewardTaskSuccessCount(t RewardPeriodicTask, count int64, nextTime timeutil.TimeStamp) error { + sess := x.NewSession() + defer sess.Close() + _, err := sess.Exec("update reward_periodic_task set success_count = success_count + ? , next_execute_time = ?, updated_unix = ? where id = ?", count, nextTime, timeutil.TimeStampNow(), t.ID) + if err != nil { + sess.Rollback() + return err + } + _, err = sess.Exec("update reward_operate_record set amount = amount + ? ,updated_unix = ? where record_id = ?", count*t.Amount, timeutil.TimeStampNow(), t.OperateRecordId) + if err != nil { + sess.Rollback() + return err + } + sess.Commit() + return nil +} + +func GetPeriodicTaskBySourceIdAndType(sourceType SourceType, sourceId string, operateType RewardOperateType) (*RewardPeriodicTask, error) { + r := RewardPeriodicTask{} + _, err := x.SQL("select rpt.* from reward_periodic_task rpt "+ + "inner join reward_operate_record ror on rpt.operate_record_id = ror.record_id"+ + " where ror.source_type = ? and source_id = ? and operate_type = ? ", sourceType.Name(), sourceId, operateType.Name()).Get(&r) + if err != nil { + return nil, err + } + return &r, nil +} + +func StopPeriodicTask(taskId int64, operateRecordId string, stopTime time.Time) error { + sess := x.NewSession() + defer sess.Close() + _, err := sess.Where("id = ? and status = ?", taskId, PeriodicTaskStatusRunning).Update(&RewardPeriodicTask{Status: PeriodicTaskStatusFinished, FinishedUnix: timeutil.TimeStamp(stopTime.Unix())}) + if err != nil { + sess.Rollback() + return err + } + _, err = sess.Where("record_id = ? and status = ?", operateRecordId, OperateStatusOperating).Update(&RewardOperateRecord{Status: OperateStatusSucceeded}) + if err != nil { + sess.Rollback() + return err + } + sess.Commit() + return nil +} diff --git a/modules/auth/modelarts.go b/modules/auth/modelarts.go index ce41f5d1e..0cbed45a6 100755 --- a/modules/auth/modelarts.go +++ b/modules/auth/modelarts.go @@ -22,6 +22,7 @@ type CreateModelArtsNotebookForm struct { Description string `form:"description"` Flavor string `form:"flavor" binding:"Required"` ImageId string `form:"image_id" binding:"Required"` + ResourceSpecId int `form:"resource_spec_id"` } func (f *CreateModelArtsNotebookForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { @@ -46,6 +47,7 @@ type CreateModelArtsTrainJobForm struct { VersionName string `form:"version_name" binding:"Required"` FlavorName string `form:"flaver_names" binding:"Required"` EngineName string `form:"engine_names" binding:"Required"` + ResourceSpecId int `form:"resource_spec_id"` } type CreateModelArtsInferenceJobForm struct { @@ -71,6 +73,7 @@ type CreateModelArtsInferenceJobForm struct { ModelName string `form:"model_name" binding:"Required"` ModelVersion string `form:"model_version" binding:"Required"` CkptName string `form:"ckpt_name" binding:"Required"` + ResourceSpecId int `form:"resource_spec_id"` } func (f *CreateModelArtsTrainJobForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { diff --git a/modules/context/point.go b/modules/context/point.go new file mode 100644 index 000000000..9fbff61be --- /dev/null +++ b/modules/context/point.go @@ -0,0 +1,19 @@ +package context + +import ( + "code.gitea.io/gitea/services/reward/point/account" + "gitea.com/macaron/macaron" +) + +// PointAccount returns a macaron to get request user's point account +func PointAccount() macaron.Handler { + return func(ctx *Context) { + a, err := account.GetAccount(ctx.User.ID) + if err != nil { + ctx.ServerError("GetPointAccount", err) + return + } + ctx.Data["PointAccount"] = a + ctx.Next() + } +} diff --git a/modules/cron/tasks_basic.go b/modules/cron/tasks_basic.go index b3a6c02a1..39100594d 100755 --- a/modules/cron/tasks_basic.go +++ b/modules/cron/tasks_basic.go @@ -5,6 +5,7 @@ package cron import ( + "code.gitea.io/gitea/services/reward" "context" "time" @@ -207,6 +208,28 @@ func registerSyncCloudbrainStatus() { }) } +func registerRewardPeriodTask() { + RegisterTaskFatal("reward_period_task", &BaseConfig{ + Enabled: true, + RunAtStart: true, + Schedule: "@every 5m", + }, func(ctx context.Context, _ *models.User, _ Config) error { + reward.StartRewardTask() + return nil + }) +} + +func registerCloudbrainPointDeductTask() { + RegisterTaskFatal("cloudbrain_point_deduct_task", &BaseConfig{ + Enabled: true, + RunAtStart: true, + Schedule: "@every 1m", + }, func(ctx context.Context, _ *models.User, _ Config) error { + reward.StartCloudbrainPointDeductTask() + return nil + }) +} + func initBasicTasks() { registerUpdateMirrorTask() registerRepoHealthCheck() @@ -227,4 +250,7 @@ func initBasicTasks() { registerSyncCloudbrainStatus() registerHandleOrgStatistic() + + registerRewardPeriodTask() + registerCloudbrainPointDeductTask() } diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 78b40fd56..de5c392cd 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -96,6 +96,7 @@ type GenerateTrainJobReq struct { VersionCount int EngineName string TotalVersionCount int + ResourceSpecId int } type GenerateInferenceJobReq struct { @@ -127,6 +128,7 @@ type GenerateInferenceJobReq struct { ModelVersion string CkptName string ResultUrl string + ResourceSpecId int } type VersionInfo struct { diff --git a/modules/redis/redis_key/reward_redis_key.go b/modules/redis/redis_key/reward_redis_key.go index add304db4..f6c9480a9 100644 --- a/modules/redis/redis_key/reward_redis_key.go +++ b/modules/redis/redis_key/reward_redis_key.go @@ -1,11 +1,16 @@ package redis_key +import "fmt" + const REWARD_REDIS_PREFIX = "reward" -func RewardSendLock(requestId string, sourceType string) string { - return KeyJoin(REWARD_REDIS_PREFIX, requestId, sourceType, "send") +func RewardOperateLock(requestId string, sourceType string, operateType string) string { + return KeyJoin(REWARD_REDIS_PREFIX, requestId, sourceType, operateType, "send") } func RewardOperateNotification() string { return KeyJoin(REWARD_REDIS_PREFIX, "operate", "notification") } +func RewardTaskRunningLock(taskId int64) string { + return KeyJoin(REWARD_REDIS_PREFIX, "periodic_task", fmt.Sprint(taskId), "lock") +} diff --git a/modules/setting/setting.go b/modules/setting/setting.go index 595c51286..b5ffe6eab 100755 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -548,6 +548,9 @@ var ( WechatQRCodeExpireSeconds int WechatAuthSwitch bool + //point config + CloudBrainTaskPointPaySwitch bool + //nginx proxy PROXYURL string RadarMap = struct { @@ -1374,7 +1377,10 @@ func NewContext() { WechatAppId = sec.Key("APP_ID").MustString("wxba77b915a305a57d") WechatAppSecret = sec.Key("APP_SECRET").MustString("e48e13f315adc32749ddc7057585f198") WechatQRCodeExpireSeconds = sec.Key("QR_CODE_EXPIRE_SECONDS").MustInt(120) - WechatAuthSwitch = sec.Key("AUTH_SWITCH").MustBool(true) + WechatAuthSwitch = sec.Key("AUTH_SWITCH").MustBool(false) + + sec = Cfg.Section("point") + CloudBrainTaskPointPaySwitch = sec.Key("CLOUDBRAIN_PAY_SWITCH").MustBool(false) SetRadarMapConfig() diff --git a/routers/repo/cloudbrain.go b/routers/repo/cloudbrain.go index 7ed6fa6ef..b4d532ab0 100755 --- a/routers/repo/cloudbrain.go +++ b/routers/repo/cloudbrain.go @@ -2,6 +2,7 @@ package repo import ( "bufio" + "code.gitea.io/gitea/services/reward" "encoding/json" "errors" "fmt" @@ -229,6 +230,13 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) { command = commandTrain } + if !reward.IsPointBalanceEnough(ctx.User.ID, jobType, resourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, jobType, resourceSpecId) + cloudBrainNewDataPrepare(ctx) + ctx.RenderWithErr("point balance not enough", tpl, &form) + return + } + tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, jobType, displayJobName) if err == nil { if len(tasks) != 0 { @@ -308,6 +316,13 @@ func CloudBrainRestart(ctx *context.Context) { var status = string(models.JobWaiting) task := ctx.Cloudbrain for { + if !reward.IsPointBalanceEnough(ctx.User.ID, task.JobType, task.ResourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, task.JobType, task.ResourceSpecId) + resultCode = "-1" + errorMsg = "insufficient points balance" + break + } + if task.Status != string(models.JobStopped) && task.Status != string(models.JobSucceeded) && task.Status != string(models.JobFailed) { log.Error("the job(%s) is not stopped", task.JobName, ctx.Data["MsgID"]) resultCode = "-1" @@ -842,7 +857,6 @@ func CloudBrainStop(ctx *context.Context) { errorMsg = "system error" break } - status = task.Status break } @@ -1845,6 +1859,13 @@ func BenchMarkAlgorithmCreate(ctx *context.Context, form auth.CreateCloudBrainFo repo := ctx.Repo.Repository + if !reward.IsPointBalanceEnough(ctx.User.ID, string(models.JobTypeBenchmark), resourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, string(models.JobTypeBenchmark), resourceSpecId) + cloudBrainNewDataPrepare(ctx) + ctx.RenderWithErr("point balance not enough", tplCloudBrainBenchmarkNew, &form) + return + } + tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, string(models.JobTypeBenchmark), displayJobName) if err == nil { if len(tasks) != 0 { @@ -2000,6 +2021,13 @@ func ModelBenchmarkCreate(ctx *context.Context, form auth.CreateCloudBrainForm) tpl := tplCloudBrainBenchmarkNew command := cloudbrain.Command + if !reward.IsPointBalanceEnough(ctx.User.ID, jobType, resourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, jobType, resourceSpecId) + cloudBrainNewDataPrepare(ctx) + ctx.RenderWithErr("point balance not enough", tpl, &form) + return + } + tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, jobType, displayJobName) if err == nil { if len(tasks) != 0 { diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 95ca8df62..dea996a50 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -2,6 +2,7 @@ package repo import ( "archive/zip" + "code.gitea.io/gitea/services/reward" "encoding/json" "errors" "fmt" @@ -204,7 +205,14 @@ func Notebook2Create(ctx *context.Context, form auth.CreateModelArtsNotebookForm flavor := form.Flavor imageId := form.ImageId repo := ctx.Repo.Repository + resourceSpecId := form.ResourceSpecId + if !reward.IsPointBalanceEnough(ctx.User.ID, string(models.JobTypeDebug), resourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, string(models.JobTypeBenchmark), resourceSpecId) + cloudBrainNewDataPrepare(ctx) + ctx.RenderWithErr("point balance not enough", tplModelArtsNotebookNew, &form) + return + } count, err := models.GetCloudbrainNotebookCountByUserID(ctx.User.ID) if err != nil { log.Error("GetCloudbrainNotebookCountByUserID failed:%v", err, ctx.Data["MsgID"]) @@ -418,6 +426,13 @@ func NotebookManage(ctx *context.Context) { errorMsg = "you have no right to restart the job" break } + if !reward.IsPointBalanceEnough(ctx.User.ID, task.JobType, task.ResourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, task.JobType, task.ResourceSpecId) + resultCode = "-1" + errorMsg = "point balance not enough" + break + return + } count, err := models.GetCloudbrainNotebookCountByUserID(ctx.User.ID) if err != nil { @@ -985,7 +1000,14 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) FlavorName := form.FlavorName VersionCount := modelarts.VersionCount EngineName := form.EngineName + resourceSpecId := form.ResourceSpecId + if !reward.IsPointBalanceEnough(ctx.User.ID, string(models.JobTypeTrain), resourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, string(models.JobTypeBenchmark), resourceSpecId) + cloudBrainNewDataPrepare(ctx) + ctx.RenderWithErr("point balance not enough", tplModelArtsTrainJobNew, &form) + return + } count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) if err != nil { log.Error("GetCloudbrainTrainJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) @@ -1161,6 +1183,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) EngineName: EngineName, VersionCount: VersionCount, TotalVersionCount: modelarts.TotalVersionCount, + ResourceSpecId: resourceSpecId, } //将params转换Parameters.Parameter,出错时返回给前端 @@ -1716,7 +1739,6 @@ func TrainJobStop(ctx *context.Context) { ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobIndex, nil) return } - ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelarts/train-job?listType=" + listType) } @@ -1825,9 +1847,16 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference modelName := form.ModelName modelVersion := form.ModelVersion ckptName := form.CkptName + resourceSpecId := form.ResourceSpecId ckptUrl := form.TrainUrl + form.CkptName + if !reward.IsPointBalanceEnough(ctx.User.ID, string(models.JobTypeInference), resourceSpecId) { + log.Error("point balance is not enough,userId=%d jobType=%s resourceSpecId=%d", ctx.User.ID, string(models.JobTypeBenchmark), resourceSpecId) + inferenceJobErrorNewDataPrepare(ctx, form) + ctx.RenderWithErr("point balance not enough", tplModelArtsInferenceJobNew, &form) + return + } count, err := models.GetCloudbrainInferenceJobCountByUserID(ctx.User.ID) if err != nil { log.Error("GetCloudbrainInferenceJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) @@ -1973,6 +2002,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference ModelVersion: modelVersion, CkptName: ckptName, ResultUrl: resultObsPath, + ResourceSpecId: resourceSpecId, } err = modelarts.GenerateInferenceJob(ctx, req) diff --git a/routers/routes/routes.go b/routers/routes/routes.go index 31075742c..3ce633f93 100755 --- a/routers/routes/routes.go +++ b/routers/routes/routes.go @@ -1068,7 +1068,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/models", reqRepoCloudBrainReader, repo.CloudBrainShowModels) m.Get("/download_model", cloudbrain.AdminOrJobCreaterRight, repo.CloudBrainDownloadModel) }) - m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.CloudBrainNew) + m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, context.PointAccount(), repo.CloudBrainNew) m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, bindIgnErr(auth.CreateCloudBrainForm{}), repo.CloudBrainCreate) m.Group("/benchmark", func() { @@ -1079,7 +1079,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.BenchmarkDel) m.Get("/rate", reqRepoCloudBrainReader, repo.GetRate) }) - m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.CloudBrainBenchmarkNew) + m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, context.PointAccount(), repo.CloudBrainBenchmarkNew) m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, bindIgnErr(auth.CreateCloudBrainForm{}), repo.CloudBrainBenchmarkCreate) m.Get("/get_child_types", repo.GetChildTypes) }) @@ -1093,7 +1093,7 @@ func RegisterRoutes(m *macaron.Macaron) { //m.Get("/create_version", reqWechatBind, cloudbrain.AdminOrJobCreaterRightForTrain, repo.TrainJobNewVersion) //m.Post("/create_version", reqWechatBind, cloudbrain.AdminOrJobCreaterRightForTrain, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreateVersion) }) - m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.CloudBrainTrainJobNew) + m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, context.PointAccount(), repo.CloudBrainTrainJobNew) m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, bindIgnErr(auth.CreateCloudBrainForm{}), repo.CloudBrainCreate) }) }, context.RepoRef()) @@ -1141,7 +1141,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Post("/:action", reqRepoCloudBrainWriter, repo.NotebookManage) m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.NotebookDel) }) - m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.NotebookNew) + m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, context.PointAccount(), repo.NotebookNew) m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsNotebookForm{}), repo.Notebook2Create) }) @@ -1155,7 +1155,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/create_version", reqWechatBind, cloudbrain.AdminOrJobCreaterRightForTrain, repo.TrainJobNewVersion) m.Post("/create_version", reqWechatBind, cloudbrain.AdminOrJobCreaterRightForTrain, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreateVersion) }) - m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.TrainJobNew) + m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, context.PointAccount(), repo.TrainJobNew) m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreate) m.Get("/para-config-list", reqRepoCloudBrainReader, repo.TrainJobGetConfigList) @@ -1168,7 +1168,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/result_download", cloudbrain.AdminOrJobCreaterRightForTrain, repo.ResultDownload) m.Get("/downloadall", repo.DownloadMultiResultFile) }) - m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.InferenceJobNew) + m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, context.PointAccount(), repo.InferenceJobNew) m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsInferenceJobForm{}), repo.InferenceJobCreate) }) }, context.RepoRef()) diff --git a/services/reward/cloubrain_deduct.go b/services/reward/cloubrain_deduct.go new file mode 100644 index 000000000..61068a87a --- /dev/null +++ b/services/reward/cloubrain_deduct.go @@ -0,0 +1,128 @@ +package reward + +import ( + "code.gitea.io/gitea/models" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/services/reward/point/account" + "encoding/json" + "fmt" + "time" +) + +var ( + ResourceSpecs *models.ResourceSpecs + TrainResourceSpecs *models.ResourceSpecs +) + +//IsPointBalanceEnough check whether the user's point balance is bigger than task unit price +func IsPointBalanceEnough(targetUserId int64, jobType string, resourceSpecId int) bool { + if !setting.CloudBrainTaskPointPaySwitch { + return true + } + spec := getResourceSpec(jobType, resourceSpecId) + if spec == nil { + return true + } + a, error := account.GetAccount(targetUserId) + if error != nil { + return false + } + return a.Balance >= spec.UnitPrice + +} + +func StartCloudBrainPointDeductTask(task models.Cloudbrain) { + if !setting.CloudBrainTaskPointPaySwitch { + return + } + + spec := getResourceSpec(task.JobType, task.ResourceSpecId) + if spec == nil || spec.UnitPrice == 0 { + return + } + + StartPeriodicTask(&models.StartPeriodicTaskOpts{ + SourceType: models.SourceTypeRunCloudbrainTask, + SourceId: getCloudBrainPointTaskSourceId(task), + TargetUserId: task.UserID, + RequestId: getCloudBrainPointTaskSourceId(task), + OperateType: models.OperateTypeDecrease, + Delay: 30 * time.Minute, + Interval: 60 * time.Minute, + UnitAmount: spec.UnitPrice, + RewardType: models.RewardTypePoint, + StartTime: time.Unix(int64(task.StartTime), 0), + }) +} + +func StopCloudBrainPointDeductTask(task models.Cloudbrain) { + StopPeriodicTask(models.SourceTypeRunCloudbrainTask, getCloudBrainPointTaskSourceId(task), models.OperateTypeDecrease) +} + +func getCloudBrainPointTaskSourceId(task models.Cloudbrain) string { + return models.SourceTypeRunCloudbrainTask.Name() + "_" + task.JobType + "_" + fmt.Sprint(task.Type) + "_" + fmt.Sprint(task.ID) +} + +func getResourceSpec(jobType string, resourceSpecId int) *models.ResourceSpec { + if jobType == string(models.JobTypeTrain) { + if TrainResourceSpecs == nil { + json.Unmarshal([]byte(setting.TrainResourceSpecs), &TrainResourceSpecs) + } + for _, spec := range TrainResourceSpecs.ResourceSpec { + if resourceSpecId == spec.Id { + return spec + } + } + } else { + if ResourceSpecs == nil { + json.Unmarshal([]byte(setting.ResourceSpecs), &ResourceSpecs) + } + for _, spec := range ResourceSpecs.ResourceSpec { + if resourceSpecId == spec.Id { + return spec + } + } + + } + return nil + +} + +var firstTimeFlag = true + +func StartCloudbrainPointDeductTask() { + defer func() { + if err := recover(); err != nil { + combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2)) + log.Error("PANIC:%v", combinedErr) + } + }() + log.Debug("try to run CloudbrainPointDeductTask") + end := time.Now() + start := end.Add(5 * time.Minute) + if firstTimeFlag { + //When it is executed for the first time, it needs to process the tasks of the last 1 hours. + //This is done to prevent the application from hanging for a long time + start = end.Add(1 * time.Hour) + firstTimeFlag = false + } + + taskList, err := models.GetStartedCloudbrainTaskByUpdatedUnix(start, end) + if err != nil { + log.Error("GetStartedCloudbrainTaskByUpdatedUnix error. %v", err) + return + } + if taskList == nil || len(taskList) == 0 { + log.Debug("No cloudbrain task need handled") + return + } + for _, t := range taskList { + if int64(t.StartTime) <= end.Unix() && int64(t.StartTime) >= start.Unix() { + StartCloudBrainPointDeductTask(t) + } + if int64(t.EndTime) <= end.Unix() && int64(t.EndTime) >= start.Unix() { + StopCloudBrainPointDeductTask(t) + } + } +} diff --git a/services/reward/operator.go b/services/reward/operator.go index 40c093b67..50ec01ff3 100644 --- a/services/reward/operator.go +++ b/services/reward/operator.go @@ -21,7 +21,7 @@ type RewardOperator interface { Operate(ctx *models.RewardOperateContext) error } -func Send(ctx *models.RewardOperateContext) error { +func Operate(ctx *models.RewardOperateContext) error { defer func() { if err := recover(); err != nil { combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2)) @@ -33,7 +33,7 @@ func Send(ctx *models.RewardOperateContext) error { return errors.New("param incorrect") } //add lock - var rewardLock = redis_lock.NewDistributeLock(redis_key.RewardSendLock(ctx.RequestId, ctx.SourceType)) + var rewardLock = redis_lock.NewDistributeLock(redis_key.RewardOperateLock(ctx.RequestId, ctx.SourceType.Name(), ctx.OperateType.Name())) isOk, err := rewardLock.Lock(3 * time.Second) if err != nil { return err @@ -45,7 +45,7 @@ func Send(ctx *models.RewardOperateContext) error { defer rewardLock.UnLock() //is handled before? - isHandled, err := isHandled(ctx.SourceType, ctx.RequestId) + isHandled, err := isHandled(ctx.SourceType.Name(), ctx.RequestId, ctx.OperateType.Name()) if err != nil { log.Error("reward is handled error,%v", err) return err @@ -61,9 +61,11 @@ func Send(ctx *models.RewardOperateContext) error { return errors.New("operator of reward type is not exist") } - //is limited? - if isLimited := operator.IsLimited(ctx); isLimited { - return nil + if ctx.OperateType == models.OperateTypeIncrease { + //is limited? + if isLimited := operator.IsLimited(ctx); isLimited { + return nil + } } //new reward operate record @@ -76,15 +78,12 @@ func Send(ctx *models.RewardOperateContext) error { //operate if err := operator.Operate(ctx); err != nil { - updateAwardOperateRecordStatus(ctx.SourceType, ctx.RequestId, models.OperateStatusOperating, models.OperateStatusFailed) + updateAwardOperateRecordStatus(ctx.SourceType.Name(), ctx.RequestId, models.OperateStatusOperating, models.OperateStatusFailed) return err } - //if not a cycle operate,update status to success - if ctx.CycleIntervalSeconds == 0 { - updateAwardOperateRecordStatus(ctx.SourceType, ctx.RequestId, models.OperateStatusOperating, models.OperateStatusSucceeded) - NotifyRewardOperation(ctx.TargetUserId, ctx.Reward.Amount, ctx.Reward.Type, ctx.OperateType) - } + updateAwardOperateRecordStatus(ctx.SourceType.Name(), ctx.RequestId, models.OperateStatusOperating, models.OperateStatusSucceeded) + NotifyRewardOperation(ctx.TargetUserId, ctx.Reward.Amount, ctx.Reward.Type, ctx.OperateType) return nil } @@ -99,8 +98,8 @@ func GetOperator(rewardType models.RewardType) RewardOperator { return RewardOperatorMap[rewardType.Name()] } -func isHandled(sourceType string, requestId string) (bool, error) { - _, err := models.GetPointOperateRecordBySourceTypeAndRequestId(sourceType, requestId) +func isHandled(sourceType string, requestId string, operateType string) (bool, error) { + _, err := models.GetPointOperateRecordBySourceTypeAndRequestId(sourceType, requestId, operateType) if err != nil { if models.IsErrRecordNotExist(err) { return false, nil @@ -113,17 +112,36 @@ func isHandled(sourceType string, requestId string) (bool, error) { func initAwardOperateRecord(ctx *models.RewardOperateContext) (string, error) { record := &models.RewardOperateRecord{ - RecordId: util.UUID(), - UserId: ctx.TargetUserId, - Amount: ctx.Reward.Amount, - RewardType: ctx.Reward.Type.Name(), - SourceType: ctx.SourceType, - SourceId: ctx.SourceId, - RequestId: ctx.RequestId, - OperateType: ctx.OperateType.Name(), - CycleIntervalSeconds: ctx.CycleIntervalSeconds, - Status: models.OperateStatusOperating, - Remark: ctx.Remark, + RecordId: util.UUID(), + UserId: ctx.TargetUserId, + Amount: ctx.Reward.Amount, + RewardType: ctx.Reward.Type.Name(), + SourceType: ctx.SourceType.Name(), + SourceId: ctx.SourceId, + RequestId: ctx.RequestId, + OperateType: ctx.OperateType.Name(), + Status: models.OperateStatusOperating, + Remark: ctx.Remark, + } + _, err := models.InsertAwardOperateRecord(record) + if err != nil { + return "", err + } + return record.RecordId, nil +} + +func createPeriodicRewardOperateRecord(ctx *models.StartPeriodicTaskOpts) (string, error) { + record := &models.RewardOperateRecord{ + RecordId: util.UUID(), + UserId: ctx.TargetUserId, + Amount: 0, + RewardType: ctx.RewardType.Name(), + SourceType: ctx.SourceType.Name(), + SourceId: ctx.SourceId, + RequestId: ctx.RequestId, + OperateType: ctx.OperateType.Name(), + Status: models.OperateStatusOperating, + Remark: ctx.Remark, } _, err := models.InsertAwardOperateRecord(record) if err != nil { @@ -139,3 +157,78 @@ func updateAwardOperateRecordStatus(sourceType, requestId, oldStatus, newStatus } return nil } + +func StartPeriodicTaskAsyn(opts *models.StartPeriodicTaskOpts) { + go StartPeriodicTask(opts) +} + +func StartPeriodicTask(opts *models.StartPeriodicTaskOpts) error { + defer func() { + if err := recover(); err != nil { + combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2)) + log.Error("PANIC:%v", combinedErr) + } + }() + //add lock + var rewardLock = redis_lock.NewDistributeLock(redis_key.RewardOperateLock(opts.RequestId, opts.SourceType.Name(), opts.OperateType.Name())) + isOk, err := rewardLock.Lock(3 * time.Second) + if err != nil { + return err + } + if !isOk { + log.Info("duplicated operate request,targetUserId=%d requestId=%s", opts.TargetUserId, opts.RequestId) + return nil + } + defer rewardLock.UnLock() + + //is handled before? + isHandled, err := isHandled(opts.SourceType.Name(), opts.RequestId, opts.OperateType.Name()) + if err != nil { + log.Error("operate is handled error,%v", err) + return err + } + if isHandled { + log.Info("operate has been handled,opts=%+v", opts) + return nil + } + //new reward operate record + recordId, err := createPeriodicRewardOperateRecord(opts) + if err != nil { + return err + } + + if err = NewRewardPeriodicTask(recordId, opts); err != nil { + updateAwardOperateRecordStatus(opts.SourceType.Name(), opts.RequestId, models.OperateStatusOperating, models.OperateStatusFailed) + return err + } + return nil +} + +func StopPeriodicTaskAsyn(sourceType models.SourceType, sourceId string, operateType models.RewardOperateType) { + go StopPeriodicTask(sourceType, sourceId, operateType) +} + +func StopPeriodicTask(sourceType models.SourceType, sourceId string, operateType models.RewardOperateType) error { + defer func() { + if err := recover(); err != nil { + combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2)) + log.Error("PANIC:%v", combinedErr) + } + }() + task, err := models.GetPeriodicTaskBySourceIdAndType(sourceType, sourceId, operateType) + if err != nil { + log.Error("StopPeriodicTask. GetPeriodicTaskBySourceIdAndType error. %v", err) + return err + } + if task == nil { + log.Info("Periodic task is not exist") + return nil + } + if task.Status == models.PeriodicTaskStatusFinished { + log.Info("Periodic task is finished") + return nil + } + now := time.Now() + RunRewardTask(*task, now) + return models.StopPeriodicTask(task.ID, task.OperateRecordId, now) +} diff --git a/services/reward/period_task.go b/services/reward/period_task.go new file mode 100644 index 000000000..d00e8d0c4 --- /dev/null +++ b/services/reward/period_task.go @@ -0,0 +1,103 @@ +package reward + +import ( + "code.gitea.io/gitea/models" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/redis/redis_key" + "code.gitea.io/gitea/modules/redis/redis_lock" + "code.gitea.io/gitea/modules/timeutil" + "fmt" + "time" +) + +func NewRewardPeriodicTask(operateRecordId string, opts *models.StartPeriodicTaskOpts) error { + task := &models.RewardPeriodicTask{} + task.DelaySeconds = int64(opts.Delay.Seconds()) + task.IntervalSeconds = int64(opts.Interval.Seconds()) + task.Amount = opts.UnitAmount + task.OperateRecordId = operateRecordId + task.Status = models.PeriodicTaskStatusRunning + task.NextExecuteTime = timeutil.TimeStamp(opts.StartTime.Add(opts.Delay).Unix()) + + _, err := models.InsertPeriodicTask(task) + return err +} + +func StartRewardTask() { + defer func() { + if err := recover(); err != nil { + combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2)) + log.Error("PANIC:%v", combinedErr) + } + }() + log.Debug("try to run reward tasks") + now := time.Now() + taskList, err := models.GetRunningRewardTask(now) + if err != nil { + log.Error("GetRunningRewardTask error. %v", err) + return + } + if taskList == nil || len(taskList) == 0 { + log.Debug("No GetRunningRewardTask need handled") + return + } + for _, t := range taskList { + RunRewardTask(t, now) + } +} + +func RunRewardTask(t models.RewardPeriodicTask, now time.Time) { + lock := redis_lock.NewDistributeLock(redis_key.RewardTaskRunningLock(t.ID)) + isOk, _ := lock.LockWithWait(3*time.Second, 3*time.Second) + if !isOk { + log.Error("get RewardTaskRunningLock failed,t=%+v", t) + return + } + defer lock.UnLock() + record, err := models.GetPointOperateRecordByRecordId(t.OperateRecordId) + if err != nil { + log.Error("RunRewardTask. GetPointOperateRecordByRecordId error. %v", err) + return + } + if record.Status != models.OperateStatusOperating { + log.Info("RunRewardTask. operate record is finished,record=%+v", record) + return + } + n, nextTime := countExecuteTimes(t, now) + if n == 0 { + return + } + //get operator + operator := GetOperator(models.GetRewardTypeInstance(record.RewardType)) + if operator == nil { + log.Error("RunRewardTask. operator of reward type is not exist") + return + } + err = operator.Operate(&models.RewardOperateContext{ + SourceType: models.SourceTypeRunCloudbrainTask, + SourceId: t.OperateRecordId, + Reward: models.Reward{ + Amount: n * t.Amount, + Type: models.GetRewardTypeInstance(record.RewardType), + }, + TargetUserId: record.UserId, + OperateType: models.GetRewardOperateTypeInstance(record.OperateType), + }) + if err != nil { + log.Error("RunRewardTask.operator operate error.%v", err) + return + } + models.IncrRewardTaskSuccessCount(t, n, nextTime) +} + +func countExecuteTimes(t models.RewardPeriodicTask, now time.Time) (int64, timeutil.TimeStamp) { + interval := t.IntervalSeconds + nextTime := int64(t.NextExecuteTime) + if nextTime > now.Unix() { + return 0, 0 + } + diff := now.Unix() - nextTime + n := diff/interval + 1 + newNextTime := timeutil.TimeStamp(nextTime + n*interval) + return n, newNextTime +} diff --git a/services/reward/point/point_operate.go b/services/reward/point/point_operate.go index 38b6b5384..4b84cdd0c 100644 --- a/services/reward/point/point_operate.go +++ b/services/reward/point/point_operate.go @@ -18,13 +18,12 @@ type PointOperator struct { } func (operator *PointOperator) IsLimited(ctx *models.RewardOperateContext) bool { - realAmount, err := limiter.CheckLimitWithFillUp(ctx.SourceType, models.LimitTypeRewardPoint, ctx.TargetUserId, ctx.Reward.Amount) + realAmount, err := limiter.CheckLimitWithFillUp(ctx.SourceType.Name(), models.LimitTypeRewardPoint, ctx.TargetUserId, ctx.Reward.Amount) if err != nil { return true } if realAmount < ctx.Reward.Amount { ctx.Remark = models.AppendRemark(ctx.Remark, fmt.Sprintf(LossMsg, ctx.Reward.Amount, realAmount)) - ctx.Reward.Amount = realAmount } return false diff --git a/services/task/task.go b/services/task/task.go index cd6ca830e..4c85ce52e 100644 --- a/services/task/task.go +++ b/services/task/task.go @@ -51,7 +51,7 @@ func accomplish(userId int64, taskType string) error { } //reward - reward.Send(&models.RewardOperateContext{ + reward.Operate(&models.RewardOperateContext{ SourceType: models.SourceTypeAccomplishTask, SourceId: logId, Reward: models.Reward{