|
- package accesstoken
-
import (
	"crypto/ed25519"
	"encoding/hex"
	"errors"
	"fmt"
	"sync"
	"time"

	"gitlink.org.cn/cloudream/common/pkgs/async"
	"gitlink.org.cn/cloudream/common/pkgs/logger"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc"
	cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types"
)
-
- type CacheEvent interface {
- IsAccessTokenCacheEvent() bool
- }
-
- type ExitEvent struct {
- CacheEvent
- Err error
- }
-
- type CacheKey struct {
- UserID cortypes.UserID
- TokenID cortypes.AccessTokenID
- }
-
// ErrTokenNotFound is the sentinel an AccessTokenLoader returns when the
// requested token does not exist. When the loader reports it, the cache stores
// an invalid entry for the key so later requests are rejected without
// reloading. Other loader errors leave the cache untouched. Callers should
// compare against it with errors.Is.
var ErrTokenNotFound = errors.New("token not found")
-
- type AccessTokenLoader func(key CacheKey) (cortypes.UserAccessToken, error)
-
- type CacheEntry struct {
- IsTokenValid bool
- Token cortypes.UserAccessToken
- PublicKey ed25519.PublicKey
- LoadedAt time.Time
- LastUsedAt time.Time
- }
-
- type Cache struct {
- lock sync.Mutex
- cache map[CacheKey]*CacheEntry
- done chan any
- loader AccessTokenLoader
- }
-
- func New(loader AccessTokenLoader) *Cache {
- return &Cache{
- cache: make(map[CacheKey]*CacheEntry),
- done: make(chan any, 1),
- loader: loader,
- }
- }
-
- func (nc *Cache) Start() *async.UnboundChannel[CacheEvent] {
- log := logger.WithField("Mod", "AccessTokenCache")
-
- ch := async.NewUnboundChannel[CacheEvent]()
- go func() {
- ticker := time.NewTicker(time.Second * 10)
- defer ticker.Stop()
-
- loop:
- for {
- select {
- case <-nc.done:
- break loop
-
- case <-ticker.C:
- nc.lock.Lock()
- for key, entry := range nc.cache {
- if !entry.IsTokenValid {
- // 无效Token的记录5分钟后删除
- if time.Since(entry.LoadedAt) > time.Minute*5 {
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("delete expired invalid token")
- delete(nc.cache, key)
- continue
- }
- } else {
- // 5分钟没有使用的Token则删除
- if time.Since(entry.LastUsedAt) > time.Minute*5 {
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("delete unused token")
- delete(nc.cache, key)
- continue
- }
-
- // 过期Token标记为无效
- if time.Now().After(entry.Token.ExpiresAt) {
- entry.IsTokenValid = false
- entry.LastUsedAt = time.Now()
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("token expired")
-
- } else if time.Since(entry.LoadedAt) > time.Minute*5 {
- // 依然有效的Token,则5分钟检查一次有效性
- go nc.load(key)
- }
- }
- }
- nc.lock.Unlock()
- }
- }
-
- ch.Send(&ExitEvent{})
- }()
- return ch
- }
-
- func (mc *Cache) Stop() {
- select {
- case mc.done <- true:
- default:
- }
- }
-
- func (mc *Cache) Get(key CacheKey) (*CacheEntry, bool) {
- var ret *CacheEntry
- var ok bool
-
- for i := 0; i < 2; i++ {
- mc.lock.Lock()
- entry, getOk := mc.cache[key]
-
- if getOk {
- ret = entry
- ok = true
- ret.LastUsedAt = time.Now()
-
- // 如果Token已经过期,则直接设置为无效Token。因为Token是随机生成的,几乎不可能把一个过期的Token再用上
- if entry.IsTokenValid && time.Now().After(entry.Token.ExpiresAt) {
- entry.IsTokenValid = false
- entry.LastUsedAt = time.Now()
- }
- }
-
- mc.lock.Unlock()
-
- if ok {
- break
- }
-
- mc.load(key)
- }
-
- return ret, ok
- }
-
- func (mc *Cache) NotifyTokenInvalid(key CacheKey) {
- log := logger.WithField("Mod", "AccessTokenCache")
-
- mc.lock.Lock()
- defer mc.lock.Unlock()
-
- entry, ok := mc.cache[key]
- if !ok {
- entry = &CacheEntry{
- IsTokenValid: false,
- LoadedAt: time.Now(),
- LastUsedAt: time.Now(),
- }
- mc.cache[key] = entry
- return
- }
-
- entry.IsTokenValid = false
- entry.LastUsedAt = time.Now()
-
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("notify token invalid")
- }
-
- func (mc *Cache) load(key CacheKey) {
- log := logger.WithField("Mod", "AccessTokenCache")
-
- loadToken, cerr := mc.loader(key)
-
- mc.lock.Lock()
- defer mc.lock.Unlock()
-
- if cerr != nil {
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Warnf("load token: %v", cerr)
-
- // 明确是无效的Token的也缓存一下,用于快速拒绝请求
- if cerr == ErrTokenNotFound {
- mc.cache[key] = &CacheEntry{
- IsTokenValid: false,
- LoadedAt: time.Now(),
- LastUsedAt: time.Now(),
- }
- }
- return
- }
-
- pubKey, err := hex.DecodeString(loadToken.PublicKey)
- if err != nil {
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Warnf("invalid public key: %v", err)
- return
- }
-
- mc.cache[key] = &CacheEntry{
- IsTokenValid: true,
- Token: loadToken,
- PublicKey: pubKey,
- LoadedAt: time.Now(),
- LastUsedAt: time.Now(),
- }
-
- log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("load token success, expires at: %v", loadToken.ExpiresAt)
- }
-
- func (mc *Cache) Verify(authInfo rpc.AccessTokenAuthInfo) bool {
- token, ok := mc.Get(CacheKey{
- UserID: authInfo.UserID,
- TokenID: authInfo.AccessTokenID,
- })
- if !ok {
- return false
- }
- if !token.IsTokenValid {
- return false
- }
-
- sig, err := hex.DecodeString(authInfo.Signature)
- if err != nil {
- return false
- }
-
- return ed25519.Verify(token.PublicKey, []byte(MakeStringToSign(authInfo.UserID, authInfo.AccessTokenID, authInfo.Nonce)), []byte(sig))
- }
-
- func MakeStringToSign(userID cortypes.UserID, tokenID cortypes.AccessTokenID, nonce string) string {
- return fmt.Sprintf("%v.%v.%v", userID, tokenID, nonce)
- }
|