From f054aaedb4e9cafabd8416a38394fcbdb7d26108 Mon Sep 17 00:00:00 2001 From: Sydonian <794346190@qq.com> Date: Fri, 30 May 2025 15:58:31 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E8=B0=83=E8=AF=95=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/pkgs/accesstoken/accesstoken.go | 7 +- common/pkgs/rpc/coordinator/user.go | 2 +- coordinator/internal/cmd/serve.go | 2 +- coordinator/internal/repl/repl.go | 19 ++--- coordinator/internal/repl/user.go | 98 ++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 14 deletions(-) diff --git a/common/pkgs/accesstoken/accesstoken.go b/common/pkgs/accesstoken/accesstoken.go index 314b7c8..0c6692c 100644 --- a/common/pkgs/accesstoken/accesstoken.go +++ b/common/pkgs/accesstoken/accesstoken.go @@ -160,12 +160,11 @@ func (mc *Cache) NotifyTokenInvalid(key CacheKey) { LastUsedAt: time.Now(), } mc.cache[key] = entry - return + } else { + entry.IsTokenValid = false + entry.LastUsedAt = time.Now() } - entry.IsTokenValid = false - entry.LastUsedAt = time.Now() - log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("notify token invalid") } diff --git a/common/pkgs/rpc/coordinator/user.go b/common/pkgs/rpc/coordinator/user.go index 2d44a5b..2da8c67 100644 --- a/common/pkgs/rpc/coordinator/user.go +++ b/common/pkgs/rpc/coordinator/user.go @@ -46,7 +46,7 @@ type UserRefreshTokenResp struct { PrivateKey string } -var _ = TokenAuth(Coordinator_UserLogin_FullMethodName) +var _ = TokenAuth(Coordinator_UserRefreshToken_FullMethodName) func (c *Client) UserRefreshToken(ctx context.Context, msg *UserRefreshToken) (*UserRefreshTokenResp, *rpc.CodeError) { if c.fusedErr != nil { diff --git a/coordinator/internal/cmd/serve.go b/coordinator/internal/cmd/serve.go index 12f4c08..24fe6ce 100644 --- a/coordinator/internal/cmd/serve.go +++ b/coordinator/internal/cmd/serve.go @@ -81,7 +81,7 @@ func serve(configPath string) { defer tktk.Stop() // 交互式命令行 - rep := repl.New(db2, tktk) + rep := repl.New(db2, tktk, accToken) replCh := rep.Start() /// 开始监听各个模块的事件 diff --git a/coordinator/internal/repl/repl.go b/coordinator/internal/repl/repl.go index 278b299..19498fd 100644 --- a/coordinator/internal/repl/repl.go +++ b/coordinator/internal/repl/repl.go @@ -8,6 +8,7 @@ import ( "github.com/c-bata/go-prompt" "github.com/spf13/cobra" "gitlink.org.cn/cloudream/common/pkgs/async" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/ticktock" "golang.org/x/term" @@ -28,17 +29,19 @@ type ExitEvent struct { } type Repl struct { - prompt *prompt.Prompt - evtCh *ReplEventChan - db *db.DB - tktk *ticktock.TickTock + prompt *prompt.Prompt + evtCh *ReplEventChan + db *db.DB + tktk *ticktock.TickTock + accessToken *accesstoken.Cache } -func New(db *db.DB, ticktock *ticktock.TickTock) *Repl { +func New(db *db.DB, ticktock *ticktock.TickTock, accessToken *accesstoken.Cache) *Repl { r := &Repl{ - evtCh: async.NewUnboundChannel[ReplEvent](), - db: db, - tktk: ticktock, + evtCh: async.NewUnboundChannel[ReplEvent](), + db: db, + tktk: ticktock, + accessToken: accessToken, } return r } diff --git a/coordinator/internal/repl/user.go b/coordinator/internal/repl/user.go index 268f9f5..7645e63 100644 --- a/coordinator/internal/repl/user.go +++ b/coordinator/internal/repl/user.go @@ -1,15 +1,21 @@ package repl import ( + "context" "encoding/hex" "fmt" "os" "github.com/spf13/cobra" + "gitlink.org.cn/cloudream/common/pkgs/logger" + stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" "golang.org/x/crypto/bcrypt" "golang.org/x/term" + "gorm.io/gorm" ) func init() { @@ -28,6 +34,16 @@ func init() { }, } userCmd.AddCommand(createCmd) + + logoutCmd := &cobra.Command{ + Use: "logout [account] [tokenID]", + Short: "logout from a user account", + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + userLogout(GetCmdCtx(cmd), args[0], cortypes.AccessTokenID(args[1])) + }, + } + userCmd.AddCommand(logoutCmd) } func userCreate(ctx *CommandContext, account string, nickName string) { @@ -60,3 +76,85 @@ func userCreate(ctx *CommandContext, account string, nickName string) { fmt.Printf("user %s created\n", user.Account) } + +func userLogout(ctx *CommandContext, account string, tokenID cortypes.AccessTokenID) { + acc, err := ctx.repl.db.User().GetByAccount(ctx.repl.db.DefCtx(), account) + if err != nil { + fmt.Printf("user %s not found\n", account) + return + } + + log := logger.WithField("UserID", acc.UserID).WithField("TokenID", tokenID) + + d := ctx.repl.db + loaded, err := db.DoTx02(d, func(tx db.SQLContext) ([]cortypes.LoadedAccessToken, error) { + token, err := d.UserAccessToken().GetByID(tx, acc.UserID, tokenID) + if err != nil { + return nil, err + } + + err = d.UserAccessToken().DeleteByID(tx, token.UserID, token.TokenID) + if err != nil { + return nil, err + } + + loaded, err := d.LoadedAccessToken().GetByUserIDAndTokenID(tx, token.UserID, token.TokenID) + if err != nil { + return nil, err + } + + err = d.LoadedAccessToken().DeleteAllByUserIDAndTokenID(tx, token.UserID, token.TokenID) + if err != nil { + return nil, err + } + + return loaded, nil + }) + if err != nil { + log.Warnf("delete access token: %v", err) + if err == gorm.ErrRecordNotFound { + return + } + + return + } + + ctx.repl.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ + UserID: acc.UserID, + TokenID: tokenID, + }) + + var loadedHubIDs []cortypes.HubID + for _, l := range loaded { + loadedHubIDs = append(loadedHubIDs, l.HubID) + } + + notifyLoadedHubs(ctx, acc.UserID, tokenID, loadedHubIDs) +} + +func notifyLoadedHubs(ctx *CommandContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID, loadedHubIDs []cortypes.HubID) { + log := logger.WithField("UserID", userID).WithField("TokenID", tokenID) + + d := ctx.repl.db + + loadedHubs, err := d.Hub().BatchGetByID(d.DefCtx(), loadedHubIDs) + if err != nil { + log.Warnf("getting hubs: %v", err) + return + } + + for _, l := range loadedHubs { + addr, ok := l.Address.(*cortypes.GRPCAddressInfo) + if !ok { + continue + } + + cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort) + // 不关心返回值 + cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{ + UserID: userID, + TokenID: tokenID, + }) + cli.Release() + } +}