| @@ -11,6 +11,7 @@ import ( | |||
| "gitlink.org.cn/cloudream/common/pkgs/async" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/ecode" | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| @@ -106,25 +107,50 @@ func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] { | |||
| continue | |||
| } | |||
| var priKeyStr string | |||
| var newToken jcstypes.UserAccessToken | |||
| corCli := stgglb.CoordinatorRPCPool.Get() | |||
| refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) | |||
| if cerr != nil { | |||
| log.Warnf("refresh token: %v", cerr) | |||
| corCli.Release() | |||
| continue | |||
| if cerr.Code != string(ecode.Unauthorized) { | |||
| log.Warnf("refresh token: %v", cerr) | |||
| corCli.Release() | |||
| continue | |||
| } | |||
| log.Warnf("unauthorized to refresh token, login again") | |||
| loginResp, cerr := corCli.UserLogin(context.Background(), &corrpc.UserLogin{ | |||
| Account: k.cfg.Account, | |||
| Password: k.cfg.Password, | |||
| }) | |||
| if cerr != nil { | |||
| log.Warnf("login: %v", cerr) | |||
| corCli.Release() | |||
| continue | |||
| } else { | |||
| priKeyStr = loginResp.PrivateKey | |||
| newToken = loginResp.Token | |||
| } | |||
| } else { | |||
| priKeyStr = refResp.PrivateKey | |||
| newToken = refResp.Token | |||
| } | |||
| priKey, err := hex.DecodeString(refResp.PrivateKey) | |||
| priKey, err := hex.DecodeString(priKeyStr) | |||
| if err != nil { | |||
| log.Warnf("decode private key: %v", err) | |||
| corCli.Release() | |||
| continue | |||
| } | |||
| log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt) | |||
| log.Infof("new token expires at %v", newToken.ExpiresAt) | |||
| k.lock.Lock() | |||
| k.token = refResp.Token | |||
| k.token = newToken | |||
| k.priKey = priKey | |||
| k.lock.Unlock() | |||
| @@ -38,14 +38,15 @@ func DoTx01[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) { | |||
| return ret, err | |||
| } | |||
| func DoTx02[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) { | |||
| var ret R | |||
| func DoTx02[R1, R2 any](db *DB, do func(tx SQLContext) (R1, R2, error)) (R1, R2, error) { | |||
| var r1 R1 | |||
| var r2 R2 | |||
| err := db.db.Transaction(func(tx *gorm.DB) error { | |||
| var err error | |||
| ret, err = do(SQLContext{tx}) | |||
| r1, r2, err = do(SQLContext{tx}) | |||
| return err | |||
| }) | |||
| return ret, err | |||
| return r1, r2, err | |||
| } | |||
| func DoTx12[T any, R any](db *DB, do func(tx SQLContext, t T) (R, error), t T) (R, error) { | |||
| @@ -24,6 +24,12 @@ func (*UserAccessTokenDB) Create(ctx SQLContext, token *jcstypes.UserAccessToken | |||
| return ctx.Table("UserAccessToken").Create(token).Error | |||
| } | |||
| func (*UserAccessTokenDB) GetExpired(ctx SQLContext, expireTime time.Time) ([]jcstypes.UserAccessToken, error) { | |||
| var ret []jcstypes.UserAccessToken | |||
| err := ctx.Table("UserAccessToken").Where("ExpiresAt < ?", expireTime).Find(&ret).Error | |||
| return ret, err | |||
| } | |||
| func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID jcstypes.UserID, tokenID jcstypes.AccessTokenID) error { | |||
| return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&jcstypes.UserAccessToken{}).Error | |||
| } | |||
| @@ -66,7 +66,7 @@ func userCreate(ctx *CommandContext, account string, nickName string) { | |||
| return | |||
| } | |||
| user, err := db.DoTx02(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) { | |||
| user, err := db.DoTx01(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) { | |||
| return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) | |||
| }) | |||
| if err != nil { | |||
| @@ -87,7 +87,7 @@ func userLogout(ctx *CommandContext, account string, tokenID jcstypes.AccessToke | |||
| log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID) | |||
| d := ctx.repl.db | |||
| loaded, err := db.DoTx02(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||
| loaded, err := db.DoTx01(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||
| token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID) | |||
| if err != nil { | |||
| return nil, err | |||
| @@ -13,7 +13,7 @@ import ( | |||
| func (svc *Service) SelectStorageHub(ctx context.Context, msg *corrpc.SelectStorageHub) (*corrpc.SelectStorageHubResp, *rpc.CodeError) { | |||
| d := svc.db | |||
| resp, err := db.DoTx02(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) { | |||
| resp, err := db.DoTx01(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) { | |||
| allLoc, err := d.HubLocation().GetAll(tx) | |||
| if err != nil { | |||
| return nil, err | |||
| @@ -122,7 +122,7 @@ func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*co | |||
| log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) | |||
| loaded, err := db.DoTx02(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||
| loaded, err := db.DoTx01(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||
| token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) | |||
| if err != nil { | |||
| return nil, err | |||
| @@ -195,7 +195,7 @@ func (svc *Service) notifyLoadedHubs(userID jcstypes.UserID, tokenID jcstypes.Ac | |||
| } | |||
| func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { | |||
| token, err := db.DoTx02(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) { | |||
| token, err := db.DoTx01(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) { | |||
| token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) | |||
| if err != nil { | |||
| return jcstypes.UserAccessToken{}, err | |||
| @@ -30,49 +30,49 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) { | |||
| log.Infof("job end, time: %v", time.Since(startTime)) | |||
| }() | |||
| expired, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||
| expd, loaded, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.UserAccessToken, []jcstypes.LoadedAccessToken, error) { | |||
| nowTime := time.Now() | |||
| expired, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) | |||
| loaded, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("get expired access token load record: %w", err) | |||
| return nil, nil, fmt.Errorf("get expired access token load record: %w", err) | |||
| } | |||
| err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("delete expired access token load record: %w", err) | |||
| return nil, nil, fmt.Errorf("delete expired access token load record: %w", err) | |||
| } | |||
| expd, err := t.db.UserAccessToken().GetExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, nil, fmt.Errorf("get expired user access token: %w", err) | |||
| } | |||
| err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("delete expired user access token: %w", err) | |||
| return nil, nil, fmt.Errorf("delete expired user access token: %w", err) | |||
| } | |||
| return expired, nil | |||
| return expd, loaded, nil | |||
| }) | |||
| if err != nil { | |||
| log.Warn(err.Error()) | |||
| return | |||
| } | |||
| uniToken := make(map[accesstoken.CacheKey]bool) | |||
| for _, t := range expired { | |||
| uniToken[accesstoken.CacheKey{ | |||
| UserID: t.UserID, | |||
| TokenID: t.TokenID, | |||
| }] = true | |||
| } | |||
| log.Infof("%v expired access token cleared", len(uniToken)) | |||
| log.Infof("%v(loaded: %v) expired access token cleared", len(expd), len(loaded)) | |||
| // 通知本服务的AccessToken缓存失效 | |||
| for k := range uniToken { | |||
| t.accessToken.NotifyTokenInvalid(k) | |||
| for _, e := range expd { | |||
| t.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ | |||
| UserID: e.UserID, | |||
| TokenID: e.TokenID, | |||
| }) | |||
| } | |||
| // 通知所有加载了失效Token的Hub | |||
| var loadedHubIDs []jcstypes.HubID | |||
| for _, e := range expired { | |||
| for _, e := range loaded { | |||
| loadedHubIDs = append(loadedHubIDs, e.HubID) | |||
| } | |||
| @@ -87,7 +87,7 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) { | |||
| hubMap[h.HubID] = h | |||
| } | |||
| for _, e := range expired { | |||
| for _, e := range loaded { | |||
| h, ok := hubMap[e.HubID] | |||
| if !ok { | |||
| continue | |||