Browse Source

优化访问Token刷新的逻辑

master
Sydonian 3 months ago
parent
commit
3bc7a93f80
7 changed files with 67 additions and 34 deletions
  1. +32
    -6
      client/internal/accesstoken/accesstoken.go
  2. +5
    -4
      coordinator/internal/db/db.go
  3. +6
    -0
      coordinator/internal/db/user_access_token.go
  4. +2
    -2
      coordinator/internal/repl/user.go
  5. +1
    -1
      coordinator/internal/rpc/storage.go
  6. +2
    -2
      coordinator/internal/rpc/user.go
  7. +19
    -19
      coordinator/internal/ticktock/clear_expired_access_token.go

+ 32
- 6
client/internal/accesstoken/accesstoken.go View File

@@ -11,6 +11,7 @@ import (


"gitlink.org.cn/cloudream/common/pkgs/async" "gitlink.org.cn/cloudream/common/pkgs/async"
"gitlink.org.cn/cloudream/common/pkgs/logger" "gitlink.org.cn/cloudream/common/pkgs/logger"
"gitlink.org.cn/cloudream/jcs-pub/common/ecode"
stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals"
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken"
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc"
@@ -106,25 +107,50 @@ func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] {
continue continue
} }


var priKeyStr string
var newToken jcstypes.UserAccessToken

corCli := stgglb.CoordinatorRPCPool.Get() corCli := stgglb.CoordinatorRPCPool.Get()
refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{})
if cerr != nil { if cerr != nil {
log.Warnf("refresh token: %v", cerr)
corCli.Release()
continue
if cerr.Code != string(ecode.Unauthorized) {
log.Warnf("refresh token: %v", cerr)
corCli.Release()
continue
}

log.Warnf("unauthorized to refresh token, login again")

loginResp, cerr := corCli.UserLogin(context.Background(), &corrpc.UserLogin{
Account: k.cfg.Account,
Password: k.cfg.Password,
})
if cerr != nil {
log.Warnf("login: %v", cerr)
corCli.Release()
continue

} else {
priKeyStr = loginResp.PrivateKey
newToken = loginResp.Token
}

} else {
priKeyStr = refResp.PrivateKey
newToken = refResp.Token
} }


priKey, err := hex.DecodeString(refResp.PrivateKey)
priKey, err := hex.DecodeString(priKeyStr)
if err != nil { if err != nil {
log.Warnf("decode private key: %v", err) log.Warnf("decode private key: %v", err)
corCli.Release() corCli.Release()
continue continue
} }


log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt)
log.Infof("new token expires at %v", newToken.ExpiresAt)


k.lock.Lock() k.lock.Lock()
k.token = refResp.Token
k.token = newToken
k.priKey = priKey k.priKey = priKey
k.lock.Unlock() k.lock.Unlock()




+ 5
- 4
coordinator/internal/db/db.go View File

@@ -38,14 +38,15 @@ func DoTx01[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) {
return ret, err return ret, err
} }


func DoTx02[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) {
var ret R
func DoTx02[R1, R2 any](db *DB, do func(tx SQLContext) (R1, R2, error)) (R1, R2, error) {
var r1 R1
var r2 R2
err := db.db.Transaction(func(tx *gorm.DB) error { err := db.db.Transaction(func(tx *gorm.DB) error {
var err error var err error
ret, err = do(SQLContext{tx})
r1, r2, err = do(SQLContext{tx})
return err return err
}) })
return ret, err
return r1, r2, err
} }


func DoTx12[T any, R any](db *DB, do func(tx SQLContext, t T) (R, error), t T) (R, error) { func DoTx12[T any, R any](db *DB, do func(tx SQLContext, t T) (R, error), t T) (R, error) {


+ 6
- 0
coordinator/internal/db/user_access_token.go View File

@@ -24,6 +24,12 @@ func (*UserAccessTokenDB) Create(ctx SQLContext, token *jcstypes.UserAccessToken
return ctx.Table("UserAccessToken").Create(token).Error return ctx.Table("UserAccessToken").Create(token).Error
} }


func (*UserAccessTokenDB) GetExpired(ctx SQLContext, expireTime time.Time) ([]jcstypes.UserAccessToken, error) {
var ret []jcstypes.UserAccessToken
err := ctx.Table("UserAccessToken").Where("ExpiresAt < ?", expireTime).Find(&ret).Error
return ret, err
}

func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID jcstypes.UserID, tokenID jcstypes.AccessTokenID) error { func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID jcstypes.UserID, tokenID jcstypes.AccessTokenID) error {
return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&jcstypes.UserAccessToken{}).Error return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&jcstypes.UserAccessToken{}).Error
} }


+ 2
- 2
coordinator/internal/repl/user.go View File

@@ -66,7 +66,7 @@ func userCreate(ctx *CommandContext, account string, nickName string) {
return return
} }


user, err := db.DoTx02(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) {
user, err := db.DoTx01(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) {
return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName)
}) })
if err != nil { if err != nil {
@@ -87,7 +87,7 @@ func userLogout(ctx *CommandContext, account string, tokenID jcstypes.AccessToke
log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID) log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID)


d := ctx.repl.db d := ctx.repl.db
loaded, err := db.DoTx02(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
loaded, err := db.DoTx01(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID) token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID)
if err != nil { if err != nil {
return nil, err return nil, err


+ 1
- 1
coordinator/internal/rpc/storage.go View File

@@ -13,7 +13,7 @@ import (


func (svc *Service) SelectStorageHub(ctx context.Context, msg *corrpc.SelectStorageHub) (*corrpc.SelectStorageHubResp, *rpc.CodeError) { func (svc *Service) SelectStorageHub(ctx context.Context, msg *corrpc.SelectStorageHub) (*corrpc.SelectStorageHubResp, *rpc.CodeError) {
d := svc.db d := svc.db
resp, err := db.DoTx02(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) {
resp, err := db.DoTx01(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) {
allLoc, err := d.HubLocation().GetAll(tx) allLoc, err := d.HubLocation().GetAll(tx)
if err != nil { if err != nil {
return nil, err return nil, err


+ 2
- 2
coordinator/internal/rpc/user.go View File

@@ -122,7 +122,7 @@ func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*co


log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID)


loaded, err := db.DoTx02(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
loaded, err := db.DoTx01(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -195,7 +195,7 @@ func (svc *Service) notifyLoadedHubs(userID jcstypes.UserID, tokenID jcstypes.Ac
} }


func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) {
token, err := db.DoTx02(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) {
token, err := db.DoTx01(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) {
token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID)
if err != nil { if err != nil {
return jcstypes.UserAccessToken{}, err return jcstypes.UserAccessToken{}, err


+ 19
- 19
coordinator/internal/ticktock/clear_expired_access_token.go View File

@@ -30,49 +30,49 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) {
log.Infof("job end, time: %v", time.Since(startTime)) log.Infof("job end, time: %v", time.Since(startTime))
}() }()


expired, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
expd, loaded, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.UserAccessToken, []jcstypes.LoadedAccessToken, error) {
nowTime := time.Now() nowTime := time.Now()
expired, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime)
loaded, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("get expired access token load record: %w", err)
return nil, nil, fmt.Errorf("get expired access token load record: %w", err)
} }


err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("delete expired access token load record: %w", err)
return nil, nil, fmt.Errorf("delete expired access token load record: %w", err)
}

expd, err := t.db.UserAccessToken().GetExpired(tx, nowTime)
if err != nil {
return nil, nil, fmt.Errorf("get expired user access token: %w", err)
} }


err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) err = t.db.UserAccessToken().DeleteExpired(tx, nowTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("delete expired user access token: %w", err)
return nil, nil, fmt.Errorf("delete expired user access token: %w", err)
} }


return expired, nil
return expd, loaded, nil
}) })
if err != nil { if err != nil {
log.Warn(err.Error()) log.Warn(err.Error())
return return
} }


uniToken := make(map[accesstoken.CacheKey]bool)
for _, t := range expired {
uniToken[accesstoken.CacheKey{
UserID: t.UserID,
TokenID: t.TokenID,
}] = true
}

log.Infof("%v expired access token cleared", len(uniToken))
log.Infof("%v(loaded: %v) expired access token cleared", len(expd), len(loaded))


// 通知本服务的AccessToken缓存失效 // 通知本服务的AccessToken缓存失效
for k := range uniToken {
t.accessToken.NotifyTokenInvalid(k)
for _, e := range expd {
t.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{
UserID: e.UserID,
TokenID: e.TokenID,
})
} }


// 通知所有加载了失效Token的Hub // 通知所有加载了失效Token的Hub


var loadedHubIDs []jcstypes.HubID var loadedHubIDs []jcstypes.HubID
for _, e := range expired {
for _, e := range loaded {
loadedHubIDs = append(loadedHubIDs, e.HubID) loadedHubIDs = append(loadedHubIDs, e.HubID)
} }


@@ -87,7 +87,7 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) {
hubMap[h.HubID] = h hubMap[h.HubID] = h
} }


for _, e := range expired {
for _, e := range loaded {
h, ok := hubMap[e.HubID] h, ok := hubMap[e.HubID]
if !ok { if !ok {
continue continue


Loading…
Cancel
Save