| @@ -11,6 +11,7 @@ import ( | |||||
| "gitlink.org.cn/cloudream/common/pkgs/async" | "gitlink.org.cn/cloudream/common/pkgs/async" | ||||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | "gitlink.org.cn/cloudream/common/pkgs/logger" | ||||
| "gitlink.org.cn/cloudream/jcs-pub/common/ecode" | |||||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | ||||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | ||||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | ||||
| @@ -106,25 +107,50 @@ func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] { | |||||
| continue | continue | ||||
| } | } | ||||
| var priKeyStr string | |||||
| var newToken jcstypes.UserAccessToken | |||||
| corCli := stgglb.CoordinatorRPCPool.Get() | corCli := stgglb.CoordinatorRPCPool.Get() | ||||
| refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) | refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) | ||||
| if cerr != nil { | if cerr != nil { | ||||
| log.Warnf("refresh token: %v", cerr) | |||||
| corCli.Release() | |||||
| continue | |||||
| if cerr.Code != string(ecode.Unauthorized) { | |||||
| log.Warnf("refresh token: %v", cerr) | |||||
| corCli.Release() | |||||
| continue | |||||
| } | |||||
| log.Warnf("unauthorized to refresh token, login again") | |||||
| loginResp, cerr := corCli.UserLogin(context.Background(), &corrpc.UserLogin{ | |||||
| Account: k.cfg.Account, | |||||
| Password: k.cfg.Password, | |||||
| }) | |||||
| if cerr != nil { | |||||
| log.Warnf("login: %v", cerr) | |||||
| corCli.Release() | |||||
| continue | |||||
| } else { | |||||
| priKeyStr = loginResp.PrivateKey | |||||
| newToken = loginResp.Token | |||||
| } | |||||
| } else { | |||||
| priKeyStr = refResp.PrivateKey | |||||
| newToken = refResp.Token | |||||
| } | } | ||||
| priKey, err := hex.DecodeString(refResp.PrivateKey) | |||||
| priKey, err := hex.DecodeString(priKeyStr) | |||||
| if err != nil { | if err != nil { | ||||
| log.Warnf("decode private key: %v", err) | log.Warnf("decode private key: %v", err) | ||||
| corCli.Release() | corCli.Release() | ||||
| continue | continue | ||||
| } | } | ||||
| log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt) | |||||
| log.Infof("new token expires at %v", newToken.ExpiresAt) | |||||
| k.lock.Lock() | k.lock.Lock() | ||||
| k.token = refResp.Token | |||||
| k.token = newToken | |||||
| k.priKey = priKey | k.priKey = priKey | ||||
| k.lock.Unlock() | k.lock.Unlock() | ||||
| @@ -38,14 +38,15 @@ func DoTx01[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) { | |||||
| return ret, err | return ret, err | ||||
| } | } | ||||
| func DoTx02[R any](db *DB, do func(tx SQLContext) (R, error)) (R, error) { | |||||
| var ret R | |||||
| func DoTx02[R1, R2 any](db *DB, do func(tx SQLContext) (R1, R2, error)) (R1, R2, error) { | |||||
| var r1 R1 | |||||
| var r2 R2 | |||||
| err := db.db.Transaction(func(tx *gorm.DB) error { | err := db.db.Transaction(func(tx *gorm.DB) error { | ||||
| var err error | var err error | ||||
| ret, err = do(SQLContext{tx}) | |||||
| r1, r2, err = do(SQLContext{tx}) | |||||
| return err | return err | ||||
| }) | }) | ||||
| return ret, err | |||||
| return r1, r2, err | |||||
| } | } | ||||
| func DoTx12[T any, R any](db *DB, do func(tx SQLContext, t T) (R, error), t T) (R, error) { | func DoTx12[T any, R any](db *DB, do func(tx SQLContext, t T) (R, error), t T) (R, error) { | ||||
| @@ -24,6 +24,12 @@ func (*UserAccessTokenDB) Create(ctx SQLContext, token *jcstypes.UserAccessToken | |||||
| return ctx.Table("UserAccessToken").Create(token).Error | return ctx.Table("UserAccessToken").Create(token).Error | ||||
| } | } | ||||
| func (*UserAccessTokenDB) GetExpired(ctx SQLContext, expireTime time.Time) ([]jcstypes.UserAccessToken, error) { | |||||
| var ret []jcstypes.UserAccessToken | |||||
| err := ctx.Table("UserAccessToken").Where("ExpiresAt < ?", expireTime).Find(&ret).Error | |||||
| return ret, err | |||||
| } | |||||
| func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID jcstypes.UserID, tokenID jcstypes.AccessTokenID) error { | func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID jcstypes.UserID, tokenID jcstypes.AccessTokenID) error { | ||||
| return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&jcstypes.UserAccessToken{}).Error | return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&jcstypes.UserAccessToken{}).Error | ||||
| } | } | ||||
| @@ -66,7 +66,7 @@ func userCreate(ctx *CommandContext, account string, nickName string) { | |||||
| return | return | ||||
| } | } | ||||
| user, err := db.DoTx02(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) { | |||||
| user, err := db.DoTx01(ctx.repl.db, func(tx db.SQLContext) (jcstypes.User, error) { | |||||
| return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) | return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) | ||||
| }) | }) | ||||
| if err != nil { | if err != nil { | ||||
| @@ -87,7 +87,7 @@ func userLogout(ctx *CommandContext, account string, tokenID jcstypes.AccessToke | |||||
| log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID) | log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID) | ||||
| d := ctx.repl.db | d := ctx.repl.db | ||||
| loaded, err := db.DoTx02(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||||
| loaded, err := db.DoTx01(d, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||||
| token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID) | token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID) | ||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| @@ -13,7 +13,7 @@ import ( | |||||
| func (svc *Service) SelectStorageHub(ctx context.Context, msg *corrpc.SelectStorageHub) (*corrpc.SelectStorageHubResp, *rpc.CodeError) { | func (svc *Service) SelectStorageHub(ctx context.Context, msg *corrpc.SelectStorageHub) (*corrpc.SelectStorageHubResp, *rpc.CodeError) { | ||||
| d := svc.db | d := svc.db | ||||
| resp, err := db.DoTx02(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) { | |||||
| resp, err := db.DoTx01(d, func(tx db.SQLContext) ([]*jcstypes.Hub, error) { | |||||
| allLoc, err := d.HubLocation().GetAll(tx) | allLoc, err := d.HubLocation().GetAll(tx) | ||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| @@ -122,7 +122,7 @@ func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*co | |||||
| log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) | log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) | ||||
| loaded, err := db.DoTx02(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||||
| loaded, err := db.DoTx01(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||||
| token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) | token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) | ||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| @@ -195,7 +195,7 @@ func (svc *Service) notifyLoadedHubs(userID jcstypes.UserID, tokenID jcstypes.Ac | |||||
| } | } | ||||
| func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { | func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { | ||||
| token, err := db.DoTx02(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) { | |||||
| token, err := db.DoTx01(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) { | |||||
| token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) | token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) | ||||
| if err != nil { | if err != nil { | ||||
| return jcstypes.UserAccessToken{}, err | return jcstypes.UserAccessToken{}, err | ||||
| @@ -30,49 +30,49 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) { | |||||
| log.Infof("job end, time: %v", time.Since(startTime)) | log.Infof("job end, time: %v", time.Since(startTime)) | ||||
| }() | }() | ||||
| expired, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) { | |||||
| expd, loaded, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]jcstypes.UserAccessToken, []jcstypes.LoadedAccessToken, error) { | |||||
| nowTime := time.Now() | nowTime := time.Now() | ||||
| expired, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) | |||||
| loaded, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, fmt.Errorf("get expired access token load record: %w", err) | |||||
| return nil, nil, fmt.Errorf("get expired access token load record: %w", err) | |||||
| } | } | ||||
| err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) | err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) | ||||
| if err != nil { | if err != nil { | ||||
| return nil, fmt.Errorf("delete expired access token load record: %w", err) | |||||
| return nil, nil, fmt.Errorf("delete expired access token load record: %w", err) | |||||
| } | |||||
| expd, err := t.db.UserAccessToken().GetExpired(tx, nowTime) | |||||
| if err != nil { | |||||
| return nil, nil, fmt.Errorf("get expired user access token: %w", err) | |||||
| } | } | ||||
| err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) | err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) | ||||
| if err != nil { | if err != nil { | ||||
| return nil, fmt.Errorf("delete expired user access token: %w", err) | |||||
| return nil, nil, fmt.Errorf("delete expired user access token: %w", err) | |||||
| } | } | ||||
| return expired, nil | |||||
| return expd, loaded, nil | |||||
| }) | }) | ||||
| if err != nil { | if err != nil { | ||||
| log.Warn(err.Error()) | log.Warn(err.Error()) | ||||
| return | return | ||||
| } | } | ||||
| uniToken := make(map[accesstoken.CacheKey]bool) | |||||
| for _, t := range expired { | |||||
| uniToken[accesstoken.CacheKey{ | |||||
| UserID: t.UserID, | |||||
| TokenID: t.TokenID, | |||||
| }] = true | |||||
| } | |||||
| log.Infof("%v expired access token cleared", len(uniToken)) | |||||
| log.Infof("%v(loaded: %v) expired access token cleared", len(expd), len(loaded)) | |||||
| // 通知本服务的AccessToken缓存失效 | // 通知本服务的AccessToken缓存失效 | ||||
| for k := range uniToken { | |||||
| t.accessToken.NotifyTokenInvalid(k) | |||||
| for _, e := range expd { | |||||
| t.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ | |||||
| UserID: e.UserID, | |||||
| TokenID: e.TokenID, | |||||
| }) | |||||
| } | } | ||||
| // 通知所有加载了失效Token的Hub | // 通知所有加载了失效Token的Hub | ||||
| var loadedHubIDs []jcstypes.HubID | var loadedHubIDs []jcstypes.HubID | ||||
| for _, e := range expired { | |||||
| for _, e := range loaded { | |||||
| loadedHubIDs = append(loadedHubIDs, e.HubID) | loadedHubIDs = append(loadedHubIDs, e.HubID) | ||||
| } | } | ||||
| @@ -87,7 +87,7 @@ func (j *ClearExpiredAccessToken) Execute(t *TickTock) { | |||||
| hubMap[h.HubID] = h | hubMap[h.HubID] = h | ||||
| } | } | ||||
| for _, e := range expired { | |||||
| for _, e := range loaded { | |||||
| h, ok := hubMap[e.HubID] | h, ok := hubMap[e.HubID] | ||||
| if !ok { | if !ok { | ||||
| continue | continue | ||||