diff --git a/client/internal/accesstoken/accesstoken.go b/client/internal/accesstoken/accesstoken.go index 1ce967c..eff441c 100644 --- a/client/internal/accesstoken/accesstoken.go +++ b/client/internal/accesstoken/accesstoken.go @@ -11,6 +11,7 @@ import ( "gitlink.org.cn/cloudream/common/pkgs/async" "gitlink.org.cn/cloudream/common/pkgs/logger" + "gitlink.org.cn/cloudream/jcs-pub/common/ecode" stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" @@ -106,25 +107,50 @@ func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] { continue } + var priKeyStr string + var newToken jcstypes.UserAccessToken + corCli := stgglb.CoordinatorRPCPool.Get() refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) if cerr != nil { - log.Warnf("refresh token: %v", cerr) - corCli.Release() - continue + if cerr.Code != string(ecode.Unauthorized) { + log.Warnf("refresh token: %v", cerr) + corCli.Release() + continue + } + + log.Warnf("unauthorized to refresh token, login again") + + loginResp, cerr := corCli.UserLogin(context.Background(), &corrpc.UserLogin{ + Account: k.cfg.Account, + Password: k.cfg.Password, + }) + if cerr != nil { + log.Warnf("login: %v", cerr) + corCli.Release() + continue + + } else { + priKeyStr = loginResp.PrivateKey + newToken = loginResp.Token + } + + } else { + priKeyStr = refResp.PrivateKey + newToken = refResp.Token } - priKey, err := hex.DecodeString(refResp.PrivateKey) + priKey, err := hex.DecodeString(priKeyStr) if err != nil { log.Warnf("decode private key: %v", err) corCli.Release() continue } - log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt) + log.Infof("new token expires at %v", newToken.ExpiresAt) k.lock.Lock() - k.token = refResp.Token + k.token = newToken k.priKey = priKey k.lock.Unlock() diff --git a/coordinator/internal/db/db.go b/coordinator/internal/db/db.go index 243d8f1..157efd5 100644 --- a/coordinator/internal/db/db.go +++ b/coordinator/internal/db/db.go @@ -38,14 +38,15 @@ func DoTx01[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) { return ret, err } -func DoTx02[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) { - var ret R +func DoTx02[R1, R2 any](db *DB, do func(tx SQLContext) (R1, R2, error)) (R1, R2, error) { + var r1 R1 + var r2 R2 err := db.db.Transaction(func(tx *gorm.DB) error { var err error - ret, err = do(SQLContext{tx}) + r1, r2, err = do(SQLContext{tx}) return err }) - return ret, err + return r1, r2, err } func DoTx12[T any, R any](db *DB, do func(tx SQLContext, t T) (R, error), t T) (R, error) { diff --git a/coordinator/internal/db/user_access_token.go b/coordinator/internal/db/user_access_token.go index a800d7a..449f67e 100644 --- a/coordinator/internal/db/user_access_token.go +++ b/coordinator/internal/db/user_access_token.go @@ -24,6 +24,12 @@ func (*UserAccessTokenDB) Create(ctx SQLContext, token *jcstypes.UserAccessToken return ctx.Table("UserAccessToken").Create(token).Error } +func (*UserAccessTokenDB) GetExpired(ctx SQLContext, expireTime time.Time) ([]jcstypes.UserAccessToken, error) { + var ret []jcstypes.UserAccessToken + err := ctx.Table("UserAccessToken").Where("ExpiresAt < ?", expireTime).Find(&ret).Error + return ret, err +} + func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID jcstypes.UserID, tokenID jcstypes.AccessTokenID) error { return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&jcstypes.UserAccessToken{}).Error } diff --git a/coordinator/internal/repl/user.go b/coordinator/internal/repl/user.go index c619e04..f9aaa1d 100644 --- a/coordinator/internal/repl/user.go +++ b/coordinator/internal/repl/user.go @@ -66,7 +66,7 @@ func userCreate(ctx *CommandContext, account string, nickName string) { return } - user, err := db.DoTx02(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) { + user, err := db.DoTx01(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) { return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) }) if err != nil { @@ -87,7 +87,7 @@ func userLogout(ctx *CommandContext, account string, tokenID jcstypes.AccessToke log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID) d := ctx.repl.db - loaded, err := db.DoTx02(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { + loaded, err := db.DoTx01(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID) if err != nil { return nil, err diff --git a/coordinator/internal/rpc/storage.go b/coordinator/internal/rpc/storage.go index 445257e..9e4dce7 100644 --- a/coordinator/internal/rpc/storage.go +++ b/coordinator/internal/rpc/storage.go @@ -13,7 +13,7 @@ import ( func (svc *Service) SelectStorageHub(ctx context.Context, msg *corrpc.SelectStorageHub) (*corrpc.SelectStorageHubResp, *rpc.CodeError) { d := svc.db - resp, err := db.DoTx02(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) { + resp, err := db.DoTx01(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) { allLoc, err := d.HubLocation().GetAll(tx) if err != nil { return nil, err diff --git a/coordinator/internal/rpc/user.go b/coordinator/internal/rpc/user.go index bbe6aec..788f311 100644 --- a/coordinator/internal/rpc/user.go +++ b/coordinator/internal/rpc/user.go @@ -122,7 +122,7 @@ func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*co log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) - loaded, err := db.DoTx02(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { + loaded, err := db.DoTx01(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) if err != nil { return nil, err @@ -195,7 +195,7 @@ func (svc *Service) notifyLoadedHubs(userID jcstypes.UserID, tokenID jcstypes.Ac } func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { - token, err := db.DoTx02(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) { + token, err := db.DoTx01(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) { token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) if err != nil { return jcstypes.UserAccessToken{}, err diff --git a/coordinator/internal/ticktock/clear_expired_access_token.go b/coordinator/internal/ticktock/clear_expired_access_token.go index 28db0c0..0b0d7c2 100644 --- a/coordinator/internal/ticktock/clear_expired_access_token.go +++ b/coordinator/internal/ticktock/clear_expired_access_token.go @@ -30,49 +30,49 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) { log.Infof("job end, time: %v", time.Since(startTime)) }() - expired, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { + expd, loaded, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.UserAccessToken, []jcstypes.LoadedAccessToken, error) { nowTime := time.Now() - expired, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) + loaded, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) if err != nil { - return nil, fmt.Errorf("get expired access token load record: %w", err) + return nil, nil, fmt.Errorf("get expired access token load record: %w", err) } err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) if err != nil { - return nil, fmt.Errorf("delete expired access token load record: %w", err) + return nil, nil, fmt.Errorf("delete expired access token load record: %w", err) + } + + expd, err := t.db.UserAccessToken().GetExpired(tx, nowTime) + if err != nil { + return nil, nil, fmt.Errorf("get expired user access token: %w", err) } err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) if err != nil { - return nil, fmt.Errorf("delete expired user access token: %w", err) + return nil, nil, fmt.Errorf("delete expired user access token: %w", err) } - return expired, nil + return expd, loaded, nil }) if err != nil { log.Warn(err.Error()) return } - uniToken := make(map[accesstoken.CacheKey]bool) - for _, t := range expired { - uniToken[accesstoken.CacheKey{ - UserID: t.UserID, - TokenID: t.TokenID, - }] = true - } - - log.Infof("%v expired access token cleared", len(uniToken)) + log.Infof("%v(loaded: %v) expired access token cleared", len(expd), len(loaded)) // 通知本服务的AccessToken缓存失效 - for k := range uniToken { - t.accessToken.NotifyTokenInvalid(k) + for _, e := range expd { + t.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ + UserID: e.UserID, + TokenID: e.TokenID, + }) } // 通知所有加载了失效Token的Hub var loadedHubIDs []jcstypes.HubID - for _, e := range expired { + for _, e := range loaded { loadedHubIDs = append(loadedHubIDs, e.HubID) } @@ -87,7 +87,7 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) { hubMap[h.HubID] = h } - for _, e := range expired { + for _, e := range loaded { h, ok := hubMap[e.HubID] if !ok { continue