| @@ -0,0 +1,175 @@ | |||
| package accesstoken | |||
| import ( | |||
| "context" | |||
| "crypto/ed25519" | |||
| "crypto/rand" | |||
| "encoding/hex" | |||
| "fmt" | |||
| "sync" | |||
| "time" | |||
| "gitlink.org.cn/cloudream/common/pkgs/async" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type KeeperEvent interface { | |||
| IsAccessTokenKeeper() bool | |||
| } | |||
| type ExitEvent struct { | |||
| KeeperEvent | |||
| Err error | |||
| } | |||
| type Keeper struct { | |||
| cfg Config | |||
| enabled bool | |||
| token cortypes.UserAccessToken | |||
| priKey ed25519.PrivateKey | |||
| lock sync.RWMutex | |||
| done chan any | |||
| } | |||
| func New(cfg Config, tempCli *corrpc.TempClient) (*Keeper, error) { | |||
| loginResp, cerr := tempCli.UserLogin(context.Background(), &corrpc.UserLogin{ | |||
| Account: cfg.Account, | |||
| Password: cfg.Password, | |||
| }) | |||
| if cerr != nil { | |||
| return nil, fmt.Errorf("login: %w", cerr.ToError()) | |||
| } | |||
| priKey, err := hex.DecodeString(loginResp.PrivateKey) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("decode private key: %w", err) | |||
| } | |||
| return &Keeper{ | |||
| cfg: cfg, | |||
| enabled: true, | |||
| token: loginResp.Token, | |||
| priKey: priKey, | |||
| done: make(chan any, 1), | |||
| }, nil | |||
| } | |||
| func NewDisabled() *Keeper { | |||
| return &Keeper{ | |||
| done: make(chan any, 1), | |||
| enabled: false, | |||
| } | |||
| } | |||
| func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] { | |||
| log := logger.WithField("Mod", "Keeper") | |||
| ch := async.NewUnboundChannel[KeeperEvent]() | |||
| go func() { | |||
| if !k.enabled { | |||
| return | |||
| } | |||
| k.lock.RLock() | |||
| log.Infof("login success, token expires at %v", k.token.ExpiresAt) | |||
| k.lock.RUnlock() | |||
| ticker := time.NewTicker(time.Minute) | |||
| defer ticker.Stop() | |||
| loop: | |||
| for { | |||
| select { | |||
| case <-k.done: | |||
| break loop | |||
| case <-ticker.C: | |||
| k.lock.RLock() | |||
| token := k.token | |||
| k.lock.RUnlock() | |||
| // 当前Token已经过期,说明之前的刷新都失败了,打个日志 | |||
| if time.Now().After(token.ExpiresAt) { | |||
| log.Warnf("token expired at %v !", token.ExpiresAt) | |||
| } | |||
| // 在Token到期前5分钟时就要开始刷新Token | |||
| tokenDeadline := token.ExpiresAt.Add(-time.Minute * 5) | |||
| if time.Now().Before(tokenDeadline) { | |||
| continue | |||
| } | |||
| corCli := stgglb.CoordinatorRPCPool.Get() | |||
| refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) | |||
| if cerr != nil { | |||
| log.Warnf("refresh token: %v", cerr) | |||
| corCli.Release() | |||
| continue | |||
| } | |||
| priKey, err := hex.DecodeString(refResp.PrivateKey) | |||
| if err != nil { | |||
| log.Warnf("decode private key: %v", err) | |||
| corCli.Release() | |||
| continue | |||
| } | |||
| log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt) | |||
| k.lock.Lock() | |||
| k.token = refResp.Token | |||
| k.priKey = priKey | |||
| k.lock.Unlock() | |||
| corCli.Release() | |||
| } | |||
| } | |||
| ch.Send(ExitEvent{}) | |||
| }() | |||
| return ch | |||
| } | |||
| func (k *Keeper) Stop() { | |||
| select { | |||
| case k.done <- true: | |||
| default: | |||
| } | |||
| } | |||
| func (k *Keeper) GetAuthInfo() (rpc.AccessTokenAuthInfo, error) { | |||
| if !k.enabled { | |||
| return rpc.AccessTokenAuthInfo{}, fmt.Errorf("function disabled") | |||
| } | |||
| k.lock.RLock() | |||
| token := k.token | |||
| k.lock.RUnlock() | |||
| bytes := make([]byte, 8) | |||
| _, err := rand.Read(bytes) | |||
| if err != nil { | |||
| return rpc.AccessTokenAuthInfo{}, fmt.Errorf("generate nonce: %w", err) | |||
| } | |||
| nonce := hex.EncodeToString(bytes) | |||
| stringToSign := accesstoken.MakeStringToSign(token.UserID, token.TokenID, nonce) | |||
| signBytes := ed25519.Sign(k.priKey, []byte(stringToSign)) | |||
| signature := hex.EncodeToString(signBytes) | |||
| return rpc.AccessTokenAuthInfo{ | |||
| UserID: token.UserID, | |||
| AccessTokenID: token.TokenID, | |||
| Nonce: nonce, | |||
| Signature: signature, | |||
| }, nil | |||
| } | |||
| @@ -0,0 +1,6 @@ | |||
| package accesstoken | |||
| type Config struct { | |||
| Account string `json:"account"` | |||
| Password string `json:"password"` | |||
| } | |||
| @@ -8,6 +8,7 @@ import ( | |||
| "github.com/spf13/cobra" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accessstat" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" | |||
| @@ -69,8 +70,42 @@ func serveHTTP(configPath string, opts serveHTTPOptions) { | |||
| } | |||
| stgglb.InitLocal(config.Cfg().Local) | |||
| stgglb.InitPools(&config.Cfg().HubRPC, &config.Cfg().CoordinatorRPC) | |||
| stgglb.StandaloneMode = opts.Standalone | |||
| stgglb.StandaloneMode = opts.Standalone || config.Cfg().AccessToken == nil | |||
| var accToken *accesstoken.Keeper | |||
| if !stgglb.StandaloneMode { | |||
| tempCli, err := config.Cfg().CoordinatorRPC.BuildTempClient() | |||
| if err != nil { | |||
| logger.Warnf("build coordinator rpc temp client: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| accToken, err = accesstoken.New(*config.Cfg().AccessToken, tempCli) | |||
| tempCli.Release() | |||
| if err != nil { | |||
| logger.Warnf("new access token keeper: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| hubRPCCfg, err := config.Cfg().HubRPC.Build(accToken) | |||
| if err != nil { | |||
| logger.Warnf("build hub rpc pool config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(accToken) | |||
| if err != nil { | |||
| logger.Warnf("build coordinator rpc pool config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| stgglb.InitPools(hubRPCCfg, corRPCCfg) | |||
| } else { | |||
| accToken = accesstoken.NewDisabled() | |||
| } | |||
| accTokenChan := accToken.Start() | |||
| defer accToken.Stop() | |||
| // 数据库 | |||
| db, err := db.NewDB(&config.Cfg().DB) | |||
| @@ -162,6 +197,7 @@ func serveHTTP(configPath string, opts serveHTTPOptions) { | |||
| /// 开始监听各个模块的事件 | |||
| accTokenEvt := accTokenChan.Receive() | |||
| evtPubEvt := evtPubChan.Receive() | |||
| acStatEvt := acStatChan.Receive() | |||
| replEvt := replCh.Receive() | |||
| @@ -171,6 +207,23 @@ func serveHTTP(configPath string, opts serveHTTPOptions) { | |||
| loop: | |||
| for { | |||
| select { | |||
| case e := <-accTokenEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive access token event: %v", err) | |||
| break loop | |||
| } | |||
| switch e := e.Value.(type) { | |||
| case accesstoken.ExitEvent: | |||
| if e.Err != nil { | |||
| logger.Errorf("access token keeper exit with error: %v", err) | |||
| } else { | |||
| logger.Info("access token keeper exited") | |||
| } | |||
| break loop | |||
| } | |||
| accTokenEvt = accTokenChan.Receive() | |||
| case e := <-evtPubEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive publisher event: %v", err) | |||
| @@ -10,6 +10,7 @@ import ( | |||
| "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accessstat" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" | |||
| @@ -80,7 +81,42 @@ func test(configPath string) { | |||
| } | |||
| stgglb.InitLocal(config.Cfg().Local) | |||
| stgglb.InitPools(&config.Cfg().HubRPC, &config.Cfg().CoordinatorRPC) | |||
| stgglb.StandaloneMode = config.Cfg().AccessToken == nil | |||
| var accToken *accesstoken.Keeper | |||
| if !stgglb.StandaloneMode { | |||
| tempCli, err := config.Cfg().CoordinatorRPC.BuildTempClient() | |||
| if err != nil { | |||
| logger.Warnf("build coordinator rpc temp client: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| accToken, err = accesstoken.New(*config.Cfg().AccessToken, tempCli) | |||
| tempCli.Release() | |||
| if err != nil { | |||
| logger.Warnf("new access token keeper: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| hubRPCCfg, err := config.Cfg().HubRPC.Build(accToken) | |||
| if err != nil { | |||
| logger.Warnf("build hub rpc pool config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(accToken) | |||
| if err != nil { | |||
| logger.Warnf("build coordinator rpc pool config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| stgglb.InitPools(hubRPCCfg, corRPCCfg) | |||
| } else { | |||
| accToken = accesstoken.NewDisabled() | |||
| } | |||
| accTokenChan := accToken.Start() | |||
| defer accToken.Stop() | |||
| // 数据库 | |||
| db, err := db.NewDB(&config.Cfg().DB) | |||
| @@ -140,13 +176,30 @@ func test(configPath string) { | |||
| os.Exit(0) | |||
| }() | |||
| /// 开始监听各个模块的事件 | |||
| accTokenEvt := accTokenChan.Receive() | |||
| evtPubEvt := evtPubChan.Receive() | |||
| acStatEvt := acStatChan.Receive() | |||
| loop: | |||
| for { | |||
| select { | |||
| case e := <-accTokenEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive access token event: %v", err) | |||
| break loop | |||
| } | |||
| switch e := e.Value.(type) { | |||
| case accesstoken.ExitEvent: | |||
| if e.Err != nil { | |||
| logger.Errorf("access token keeper exit with error: %v", err) | |||
| } else { | |||
| logger.Info("access token keeper exited") | |||
| } | |||
| break loop | |||
| } | |||
| accTokenEvt = accTokenChan.Receive() | |||
| case e := <-evtPubEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive publisher event: %v", err) | |||
| @@ -9,6 +9,7 @@ import ( | |||
| "github.com/spf13/cobra" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accessstat" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" | |||
| @@ -60,7 +61,42 @@ func vfsTest(configPath string, opts serveHTTPOptions) { | |||
| } | |||
| stgglb.InitLocal(config.Cfg().Local) | |||
| stgglb.InitPools(&config.Cfg().HubRPC, &config.Cfg().CoordinatorRPC) | |||
| stgglb.StandaloneMode = opts.Standalone || config.Cfg().AccessToken == nil | |||
| var accToken *accesstoken.Keeper | |||
| if !opts.Standalone { | |||
| tempCli, err := config.Cfg().CoordinatorRPC.BuildTempClient() | |||
| if err != nil { | |||
| logger.Warnf("build coordinator rpc temp client: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| accToken, err = accesstoken.New(*config.Cfg().AccessToken, tempCli) | |||
| tempCli.Release() | |||
| if err != nil { | |||
| logger.Warnf("new access token keeper: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| hubRPCCfg, err := config.Cfg().HubRPC.Build(accToken) | |||
| if err != nil { | |||
| logger.Warnf("build hub rpc pool config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(accToken) | |||
| if err != nil { | |||
| logger.Warnf("build coordinator rpc pool config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| stgglb.InitPools(hubRPCCfg, corRPCCfg) | |||
| } else { | |||
| accToken = accesstoken.NewDisabled() | |||
| } | |||
| accTokenChan := accToken.Start() | |||
| defer accToken.Stop() | |||
| // 数据库 | |||
| db, err := db.NewDB(&config.Cfg().DB) | |||
| @@ -152,6 +188,7 @@ func vfsTest(configPath string, opts serveHTTPOptions) { | |||
| /// 开始监听各个模块的事件 | |||
| accTokenEvt := accTokenChan.Receive() | |||
| evtPubEvt := evtPubChan.Receive() | |||
| acStatEvt := acStatChan.Receive() | |||
| httpEvt := httpChan.Receive() | |||
| @@ -160,6 +197,23 @@ func vfsTest(configPath string, opts serveHTTPOptions) { | |||
| loop: | |||
| for { | |||
| select { | |||
| case e := <-accTokenEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive access token event: %v", err) | |||
| break loop | |||
| } | |||
| switch e := e.Value.(type) { | |||
| case accesstoken.ExitEvent: | |||
| if e.Err != nil { | |||
| logger.Errorf("access token keeper exit with error: %v", err) | |||
| } else { | |||
| logger.Info("access token keeper exited") | |||
| } | |||
| break loop | |||
| } | |||
| accTokenEvt = accTokenChan.Receive() | |||
| case e := <-evtPubEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive publisher event: %v", err) | |||
| @@ -3,6 +3,7 @@ package config | |||
| import ( | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/common/utils/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" | |||
| "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader/strategy" | |||
| @@ -18,8 +19,8 @@ import ( | |||
| type Config struct { | |||
| Local stgglb.LocalMachineInfo `json:"local"` | |||
| HubRPC hubrpc.PoolConfig `json:"hubRPC"` | |||
| CoordinatorRPC corrpc.PoolConfig `json:"coordinatorRPC"` | |||
| HubRPC hubrpc.PoolConfigJSON `json:"hubRPC"` | |||
| CoordinatorRPC corrpc.PoolConfigJSON `json:"coordinatorRPC"` | |||
| Logger logger.Config `json:"logger"` | |||
| DB db.Config `json:"db"` | |||
| SysEvent sysevent.Config `json:"sysEvent"` | |||
| @@ -29,6 +30,7 @@ type Config struct { | |||
| TickTock ticktock.Config `json:"tickTock"` | |||
| HTTP *http.Config `json:"http"` | |||
| Mount *mntcfg.Config `json:"mount"` | |||
| AccessToken *accesstoken.Config `json:"accessToken"` | |||
| } | |||
| var cfg Config | |||
| @@ -5,9 +5,16 @@ | |||
| "externalIP": "127.0.0.1", | |||
| "locationID": 1 | |||
| }, | |||
| "hubRPC": {}, | |||
| "hubRPC": { | |||
| "rootCA": "", | |||
| "clientCert": "", | |||
| "clientKey": "" | |||
| }, | |||
| "coordinatorRPC": { | |||
| "address": "127.0.0.1:5009" | |||
| "address": "127.0.0.1:5009", | |||
| "rootCA": "", | |||
| "clientCert": "", | |||
| "clientKey": "" | |||
| }, | |||
| "logger": { | |||
| "output": "stdout", | |||
| @@ -62,5 +69,9 @@ | |||
| "cacheActiveTime": "1m", | |||
| "cacheExpireTime": "1m", | |||
| "scanDataDirInterval": "10m" | |||
| }, | |||
| "accessToken": { | |||
| "account": "", | |||
| "password": "" | |||
| } | |||
| } | |||
| @@ -15,6 +15,14 @@ | |||
| "hubUnavailableTime": "20s" | |||
| }, | |||
| "rpc": { | |||
| "listen": "127.0.0.1:5009" | |||
| "listen": "127.0.0.1:5009", | |||
| "rootCA": "", | |||
| "serverCert": "", | |||
| "serverKey": "" | |||
| }, | |||
| "hubRPC": { | |||
| "rootCA": "", | |||
| "clientCert": "", | |||
| "clientKey": "" | |||
| } | |||
| } | |||
| @@ -1,24 +0,0 @@ | |||
| { | |||
| "logger": { | |||
| "output": "file", | |||
| "outputFileName": "datamap", | |||
| "outputDirectory": "log", | |||
| "level": "debug" | |||
| }, | |||
| "db": { | |||
| "address": "106.75.6.194:3306", | |||
| "account": "root", | |||
| "password": "cloudream123456", | |||
| "databaseName": "cloudream" | |||
| }, | |||
| "rabbitMQ": { | |||
| "address": "106.75.6.194:5672", | |||
| "account": "cloudream", | |||
| "password": "123456", | |||
| "vhost": "/", | |||
| "param": { | |||
| "retryNum": 5, | |||
| "retryInterval": 5000 | |||
| } | |||
| } | |||
| } | |||
| @@ -6,13 +6,24 @@ | |||
| "locationID": 1 | |||
| }, | |||
| "rpc": { | |||
| "listen": "127.0.0.1:5010" | |||
| "listen": "127.0.0.1:5010", | |||
| "rootCA": "", | |||
| "serverCert": "", | |||
| "serverKey": "" | |||
| }, | |||
| "http": { | |||
| "listen": "127.0.0.1:5110" | |||
| }, | |||
| "coordinatorRPC": { | |||
| "address": "127.0.0.1:5009" | |||
| "address": "127.0.0.1:5009", | |||
| "rootCA": "", | |||
| "clientCert": "", | |||
| "clientKey": "" | |||
| }, | |||
| "hubRPC": { | |||
| "rootCA": "", | |||
| "clientCert": "", | |||
| "clientKey": "" | |||
| }, | |||
| "logger": { | |||
| "output": "file", | |||
| @@ -0,0 +1,233 @@ | |||
| package accesstoken | |||
| import ( | |||
| "crypto/ed25519" | |||
| "encoding/hex" | |||
| "fmt" | |||
| "sync" | |||
| "time" | |||
| "gitlink.org.cn/cloudream/common/pkgs/async" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type CacheEvent interface { | |||
| IsAccessTokenCacheEvent() bool | |||
| } | |||
| type ExitEvent struct { | |||
| CacheEvent | |||
| Err error | |||
| } | |||
| type CacheKey struct { | |||
| UserID cortypes.UserID | |||
| TokenID cortypes.AccessTokenID | |||
| } | |||
| var ErrTokenNotFound = fmt.Errorf("token not found") | |||
| type AccessTokenLoader func(key CacheKey) (cortypes.UserAccessToken, error) | |||
| type CacheEntry struct { | |||
| IsTokenValid bool | |||
| Token cortypes.UserAccessToken | |||
| PublicKey ed25519.PublicKey | |||
| LoadedAt time.Time | |||
| LastUsedAt time.Time | |||
| } | |||
| type Cache struct { | |||
| lock sync.Mutex | |||
| cache map[CacheKey]*CacheEntry | |||
| done chan any | |||
| loader AccessTokenLoader | |||
| } | |||
| func New(loader AccessTokenLoader) *Cache { | |||
| return &Cache{ | |||
| cache: make(map[CacheKey]*CacheEntry), | |||
| done: make(chan any, 1), | |||
| loader: loader, | |||
| } | |||
| } | |||
| func (nc *Cache) Start() *async.UnboundChannel[CacheEvent] { | |||
| log := logger.WithField("Mod", "AccessTokenCache") | |||
| ch := async.NewUnboundChannel[CacheEvent]() | |||
| go func() { | |||
| ticker := time.NewTicker(time.Second * 10) | |||
| defer ticker.Stop() | |||
| loop: | |||
| for { | |||
| select { | |||
| case <-nc.done: | |||
| break loop | |||
| case <-ticker.C: | |||
| nc.lock.Lock() | |||
| for key, entry := range nc.cache { | |||
| if !entry.IsTokenValid { | |||
| // 无效Token的记录5分钟后删除 | |||
| if time.Since(entry.LoadedAt) > time.Minute*5 { | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("delete expired invalid token") | |||
| delete(nc.cache, key) | |||
| continue | |||
| } | |||
| } else { | |||
| // 5分钟没有使用的Token则删除 | |||
| if time.Since(entry.LastUsedAt) > time.Minute*5 { | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("delete unused token") | |||
| delete(nc.cache, key) | |||
| continue | |||
| } | |||
| // 过期Token标记为无效 | |||
| if time.Now().After(entry.Token.ExpiresAt) { | |||
| entry.IsTokenValid = false | |||
| entry.LastUsedAt = time.Now() | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("token expired") | |||
| } else if time.Since(entry.LoadedAt) > time.Minute*5 { | |||
| // 依然有效的Token,则5分钟检查一次有效性 | |||
| go nc.load(key) | |||
| } | |||
| } | |||
| } | |||
| nc.lock.Unlock() | |||
| } | |||
| } | |||
| ch.Send(&ExitEvent{}) | |||
| }() | |||
| return ch | |||
| } | |||
| func (mc *Cache) Stop() { | |||
| select { | |||
| case mc.done <- true: | |||
| default: | |||
| } | |||
| } | |||
| func (mc *Cache) Get(key CacheKey) (*CacheEntry, bool) { | |||
| var ret *CacheEntry | |||
| var ok bool | |||
| for i := 0; i < 2; i++ { | |||
| mc.lock.Lock() | |||
| entry, getOk := mc.cache[key] | |||
| if getOk { | |||
| ret = entry | |||
| ok = true | |||
| ret.LastUsedAt = time.Now() | |||
| // 如果Token已经过期,则直接设置为无效Token。因为Token是随机生成的,几乎不可能把一个过期的Token再用上 | |||
| if entry.IsTokenValid && time.Now().After(entry.Token.ExpiresAt) { | |||
| entry.IsTokenValid = false | |||
| entry.LastUsedAt = time.Now() | |||
| } | |||
| } | |||
| mc.lock.Unlock() | |||
| if ok { | |||
| break | |||
| } | |||
| mc.load(key) | |||
| } | |||
| return ret, ok | |||
| } | |||
| func (mc *Cache) NotifyTokenInvalid(key CacheKey) { | |||
| log := logger.WithField("Mod", "AccessTokenCache") | |||
| mc.lock.Lock() | |||
| defer mc.lock.Unlock() | |||
| entry, ok := mc.cache[key] | |||
| if !ok { | |||
| entry = &CacheEntry{ | |||
| IsTokenValid: false, | |||
| LoadedAt: time.Now(), | |||
| LastUsedAt: time.Now(), | |||
| } | |||
| mc.cache[key] = entry | |||
| return | |||
| } | |||
| entry.IsTokenValid = false | |||
| entry.LastUsedAt = time.Now() | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("notify token invalid") | |||
| } | |||
| func (mc *Cache) load(key CacheKey) { | |||
| log := logger.WithField("Mod", "AccessTokenCache") | |||
| loadToken, cerr := mc.loader(key) | |||
| mc.lock.Lock() | |||
| defer mc.lock.Unlock() | |||
| if cerr != nil { | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Warnf("load token: %v", cerr) | |||
| // 明确是无效的Token的也缓存一下,用于快速拒绝请求 | |||
| if cerr == ErrTokenNotFound { | |||
| mc.cache[key] = &CacheEntry{ | |||
| IsTokenValid: false, | |||
| LoadedAt: time.Now(), | |||
| LastUsedAt: time.Now(), | |||
| } | |||
| } | |||
| return | |||
| } | |||
| pubKey, err := hex.DecodeString(loadToken.PublicKey) | |||
| if err != nil { | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Warnf("invalid public key: %v", err) | |||
| return | |||
| } | |||
| mc.cache[key] = &CacheEntry{ | |||
| IsTokenValid: true, | |||
| Token: loadToken, | |||
| PublicKey: pubKey, | |||
| LoadedAt: time.Now(), | |||
| LastUsedAt: time.Now(), | |||
| } | |||
| log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("load token success, expires at: %v", loadToken.ExpiresAt) | |||
| } | |||
| func (mc *Cache) Verify(authInfo rpc.AccessTokenAuthInfo) bool { | |||
| token, ok := mc.Get(CacheKey{ | |||
| UserID: authInfo.UserID, | |||
| TokenID: authInfo.AccessTokenID, | |||
| }) | |||
| if !ok { | |||
| return false | |||
| } | |||
| if !token.IsTokenValid { | |||
| return false | |||
| } | |||
| sig, err := hex.DecodeString(authInfo.Signature) | |||
| if err != nil { | |||
| return false | |||
| } | |||
| return ed25519.Verify(token.PublicKey, []byte(MakeStringToSign(authInfo.UserID, authInfo.AccessTokenID, authInfo.Nonce)), []byte(sig)) | |||
| } | |||
| func MakeStringToSign(userID cortypes.UserID, tokenID cortypes.AccessTokenID, nonce string) string { | |||
| return fmt.Sprintf("%v.%v.%v", userID, tokenID, nonce) | |||
| } | |||
| @@ -0,0 +1,232 @@ | |||
| package rpc | |||
| import ( | |||
| "crypto/tls" | |||
| "fmt" | |||
| "strconv" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| "golang.org/x/net/context" | |||
| "google.golang.org/grpc" | |||
| "google.golang.org/grpc/codes" | |||
| "google.golang.org/grpc/credentials" | |||
| "google.golang.org/grpc/metadata" | |||
| "google.golang.org/grpc/peer" | |||
| "google.golang.org/grpc/status" | |||
| ) | |||
| const ( | |||
| ClientAPISNIV1 = "rpc.client.jcs-pub.v1" | |||
| InternalAPISNIV1 = "rpc.internal.jcs-pub.v1" | |||
| MetaUserID = "x-jcs-user-id" | |||
| MetaAccessTokenID = "x-jcs-access-token-id" | |||
| MetaNonce = "x-jcs-nonce" | |||
| MetaSignature = "x-jcs-signature" | |||
| MetaTokenAuthInfo = "x-jcs-token-auth-info" | |||
| ) | |||
| type AccessTokenAuthInfo struct { | |||
| UserID cortypes.UserID | |||
| AccessTokenID cortypes.AccessTokenID | |||
| Nonce string | |||
| Signature string | |||
| } | |||
| type AccessTokenVerifier interface { | |||
| Verify(authInfo AccessTokenAuthInfo) bool | |||
| } | |||
| type AccessTokenProvider interface { | |||
| GetAuthInfo() (AccessTokenAuthInfo, error) | |||
| } | |||
| func (s *ServerBase) tlsConfigSelector(hello *tls.ClientHelloInfo) (*tls.Config, error) { | |||
| switch hello.ServerName { | |||
| case ClientAPISNIV1: | |||
| return &tls.Config{ | |||
| Certificates: []tls.Certificate{s.serverCert}, | |||
| ClientAuth: tls.NoClientCert, | |||
| NextProtos: []string{"h2"}, | |||
| }, nil | |||
| case InternalAPISNIV1: | |||
| return &tls.Config{ | |||
| Certificates: []tls.Certificate{s.serverCert}, | |||
| ClientAuth: tls.RequireAndVerifyClientCert, | |||
| ClientCAs: s.rootCA, | |||
| NextProtos: []string{"h2"}, | |||
| }, nil | |||
| default: | |||
| return nil, fmt.Errorf("unknown server name: %s", hello.ServerName) | |||
| } | |||
| } | |||
| func (s *ServerBase) authUnary( | |||
| ctx context.Context, | |||
| req interface{}, | |||
| info *grpc.UnaryServerInfo, | |||
| handler grpc.UnaryHandler, | |||
| ) (resp any, err error) { | |||
| pr, ok := peer.FromContext(ctx) | |||
| if !ok { | |||
| return nil, status.Error(codes.Unauthenticated, "no peer found in context") | |||
| } | |||
| tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo) | |||
| if !ok { | |||
| return nil, status.Error(codes.Unauthenticated, "no tls info found in peer") | |||
| } | |||
| // 如果是使用interanl ServerName通过的TLS认证,则直接放行 | |||
| if tlsInfo.State.ServerName == InternalAPISNIV1 { | |||
| return handler(ctx, req) | |||
| } | |||
| // 如果是无需认证的API,则直接放行 | |||
| if s.noAuthAPIs[info.FullMethod] { | |||
| return handler(ctx, req) | |||
| } | |||
| // 否则要进行额外的Token认证 | |||
| if !s.accessTokenAuthAPIs[info.FullMethod] { | |||
| return nil, status.Error(codes.Unauthenticated, "unauthorized access") | |||
| } | |||
| meta, ok := metadata.FromIncomingContext(ctx) | |||
| if !ok { | |||
| return nil, status.Error(codes.Unauthenticated, "no metadata found in context") | |||
| } | |||
| userIDs := meta.Get(MetaUserID) | |||
| if len(userIDs) != 1 { | |||
| return nil, status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata") | |||
| } | |||
| userID, err := strconv.ParseInt(userIDs[0], 10, 64) | |||
| if err != nil { | |||
| return nil, status.Error(codes.Unauthenticated, "invalid user id in metadata") | |||
| } | |||
| accessTokenIDs := meta.Get(MetaAccessTokenID) | |||
| if len(accessTokenIDs) != 1 { | |||
| return nil, status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata") | |||
| } | |||
| nonce := meta.Get(MetaNonce) | |||
| if len(nonce) != 1 { | |||
| return nil, status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata") | |||
| } | |||
| signature := meta.Get(MetaSignature) | |||
| if len(signature) != 1 { | |||
| return nil, status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata") | |||
| } | |||
| authInfo := AccessTokenAuthInfo{ | |||
| UserID: cortypes.UserID(userID), | |||
| AccessTokenID: cortypes.AccessTokenID(accessTokenIDs[0]), | |||
| Nonce: nonce[0], | |||
| Signature: signature[0], | |||
| } | |||
| if !s.tokenVerifier.Verify(authInfo) { | |||
| return nil, status.Error(codes.Unauthenticated, "invalid access token") | |||
| } | |||
| ctx = context.WithValue(ctx, MetaTokenAuthInfo, authInfo) | |||
| return handler(ctx, req) | |||
| } | |||
| func (s *ServerBase) authStream( | |||
| srv any, | |||
| stream grpc.ServerStream, | |||
| info *grpc.StreamServerInfo, | |||
| handler grpc.StreamHandler, | |||
| ) error { | |||
| pr, ok := peer.FromContext(stream.Context()) | |||
| if !ok { | |||
| return status.Error(codes.Unauthenticated, "no peer found in context") | |||
| } | |||
| tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo) | |||
| if !ok { | |||
| return status.Error(codes.Unauthenticated, "no tls info found in peer") | |||
| } | |||
| // 如果是使用interanl ServerName通过的TLS认证,则直接放行 | |||
| if tlsInfo.State.ServerName == InternalAPISNIV1 { | |||
| return handler(srv, stream) | |||
| } | |||
| // 如果是无需认证的API,则直接放行 | |||
| if s.noAuthAPIs[info.FullMethod] { | |||
| return handler(srv, stream) | |||
| } | |||
| // 否则要进行额外的Token认证 | |||
| if !s.accessTokenAuthAPIs[info.FullMethod] { | |||
| return status.Error(codes.Unauthenticated, "unauthorized access") | |||
| } | |||
| meta, ok := metadata.FromIncomingContext(stream.Context()) | |||
| if !ok { | |||
| return status.Error(codes.Unauthenticated, "no metadata found in context") | |||
| } | |||
| userIDs := meta.Get(MetaUserID) | |||
| if len(userIDs) != 1 { | |||
| return status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata") | |||
| } | |||
| userID, err := strconv.ParseInt(userIDs[0], 10, 64) | |||
| if err != nil { | |||
| return status.Error(codes.Unauthenticated, "invalid user id in metadata") | |||
| } | |||
| accessTokenIDs := meta.Get(MetaAccessTokenID) | |||
| if len(accessTokenIDs) != 1 { | |||
| return status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata") | |||
| } | |||
| nonce := meta.Get(MetaNonce) | |||
| if len(nonce) != 1 { | |||
| return status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata") | |||
| } | |||
| signature := meta.Get(MetaSignature) | |||
| if len(signature) != 1 { | |||
| return status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata") | |||
| } | |||
| authInfo := AccessTokenAuthInfo{ | |||
| UserID: cortypes.UserID(userID), | |||
| AccessTokenID: cortypes.AccessTokenID(accessTokenIDs[0]), | |||
| Nonce: nonce[0], | |||
| Signature: signature[0], | |||
| } | |||
| if !s.tokenVerifier.Verify(authInfo) { | |||
| return status.Error(codes.Unauthenticated, "invalid access token") | |||
| } | |||
| return handler(srv, &serverStream{stream, context.WithValue(stream.Context(), MetaTokenAuthInfo, authInfo)}) | |||
| } | |||
| type serverStream struct { | |||
| grpc.ServerStream | |||
| ctx context.Context | |||
| } | |||
| func (s *serverStream) Context() context.Context { | |||
| return s.ctx | |||
| } | |||
| func GetAuthInfo(ctx context.Context) (AccessTokenAuthInfo, bool) { | |||
| val := ctx.Value(MetaTokenAuthInfo) | |||
| if val == nil { | |||
| return AccessTokenAuthInfo{}, false | |||
| } | |||
| authInfo, ok := val.(AccessTokenAuthInfo) | |||
| return authInfo, ok | |||
| } | |||
| @@ -14,9 +14,17 @@ type Client struct { | |||
| func (c *Client) Release() { | |||
| if c.con != nil { | |||
| c.pool.release() | |||
| c.pool.connPool.Release(c.pool.cfg.Address) | |||
| } | |||
| } | |||
| type TempClient struct { | |||
| Client | |||
| } | |||
| func (c *TempClient) Release() { | |||
| c.con.Close() | |||
| } | |||
| // 客户端的API要和服务端的API保持一致 | |||
| var _ CoordinatorAPI = (*Client)(nil) | |||
| @@ -27,7 +27,7 @@ var file_pkgs_rpc_coordinator_coordinator_proto_rawDesc = []byte{ | |||
| 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2f, 0x63, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, 0x61, 0x74, | |||
| 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, | |||
| 0x1a, 0x12, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x72, 0x70, 0x63, 0x2e, 0x70, | |||
| 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xfe, 0x01, 0x0a, 0x0b, 0x43, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, | |||
| 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xb7, 0x03, 0x0a, 0x0b, 0x43, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, | |||
| 0x61, 0x74, 0x6f, 0x72, 0x12, 0x2b, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x48, 0x75, 0x62, 0x43, 0x6f, | |||
| 0x6e, 0x66, 0x69, 0x67, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, | |||
| 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, | |||
| @@ -43,11 +43,23 @@ var file_pkgs_rpc_coordinator_coordinator_proto_rawDesc = []byte{ | |||
| 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2f, 0x0a, 0x10, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x53, 0x74, | |||
| 0x6f, 0x72, 0x61, 0x67, 0x65, 0x48, 0x75, 0x62, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, | |||
| 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, | |||
| 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, | |||
| 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, 0x6e, 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, | |||
| 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, 0x70, 0x75, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, | |||
| 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, | |||
| 0x3b, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, | |||
| 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x28, 0x0a, 0x09, 0x55, 0x73, 0x65, 0x72, 0x4c, 0x6f, 0x67, | |||
| 0x69, 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, | |||
| 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, | |||
| 0x2f, 0x0a, 0x10, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, | |||
| 0x6b, 0x65, 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, | |||
| 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, | |||
| 0x12, 0x29, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x72, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x0c, | |||
| 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, | |||
| 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x31, 0x0a, 0x12, 0x48, | |||
| 0x75, 0x62, 0x4c, 0x6f, 0x61, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, | |||
| 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, | |||
| 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, | |||
| 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, 0x6e, | |||
| 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, 0x70, | |||
| 0x75, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, | |||
| 0x70, 0x63, 0x2f, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, 0x3b, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, | |||
| 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, | |||
| } | |||
| var file_pkgs_rpc_coordinator_coordinator_proto_goTypes = []any{ | |||
| @@ -60,13 +72,21 @@ var file_pkgs_rpc_coordinator_coordinator_proto_depIdxs = []int32{ | |||
| 0, // 2: corrpc.Coordinator.GetHubConnectivities:input_type -> rpc.Request | |||
| 0, // 3: corrpc.Coordinator.ReportHubConnectivity:input_type -> rpc.Request | |||
| 0, // 4: corrpc.Coordinator.SelectStorageHub:input_type -> rpc.Request | |||
| 1, // 5: corrpc.Coordinator.GetHubConfig:output_type -> rpc.Response | |||
| 1, // 6: corrpc.Coordinator.GetHubs:output_type -> rpc.Response | |||
| 1, // 7: corrpc.Coordinator.GetHubConnectivities:output_type -> rpc.Response | |||
| 1, // 8: corrpc.Coordinator.ReportHubConnectivity:output_type -> rpc.Response | |||
| 1, // 9: corrpc.Coordinator.SelectStorageHub:output_type -> rpc.Response | |||
| 5, // [5:10] is the sub-list for method output_type | |||
| 0, // [0:5] is the sub-list for method input_type | |||
| 0, // 5: corrpc.Coordinator.UserLogin:input_type -> rpc.Request | |||
| 0, // 6: corrpc.Coordinator.UserRefreshToken:input_type -> rpc.Request | |||
| 0, // 7: corrpc.Coordinator.UserLogout:input_type -> rpc.Request | |||
| 0, // 8: corrpc.Coordinator.HubLoadAccessToken:input_type -> rpc.Request | |||
| 1, // 9: corrpc.Coordinator.GetHubConfig:output_type -> rpc.Response | |||
| 1, // 10: corrpc.Coordinator.GetHubs:output_type -> rpc.Response | |||
| 1, // 11: corrpc.Coordinator.GetHubConnectivities:output_type -> rpc.Response | |||
| 1, // 12: corrpc.Coordinator.ReportHubConnectivity:output_type -> rpc.Response | |||
| 1, // 13: corrpc.Coordinator.SelectStorageHub:output_type -> rpc.Response | |||
| 1, // 14: corrpc.Coordinator.UserLogin:output_type -> rpc.Response | |||
| 1, // 15: corrpc.Coordinator.UserRefreshToken:output_type -> rpc.Response | |||
| 1, // 16: corrpc.Coordinator.UserLogout:output_type -> rpc.Response | |||
| 1, // 17: corrpc.Coordinator.HubLoadAccessToken:output_type -> rpc.Response | |||
| 9, // [9:18] is the sub-list for method output_type | |||
| 0, // [0:9] is the sub-list for method input_type | |||
| 0, // [0:0] is the sub-list for extension type_name | |||
| 0, // [0:0] is the sub-list for extension extendee | |||
| 0, // [0:0] is the sub-list for field type_name | |||
| @@ -14,4 +14,9 @@ service Coordinator { | |||
| rpc ReportHubConnectivity(rpc.Request) returns(rpc.Response); | |||
| rpc SelectStorageHub(rpc.Request) returns(rpc.Response); | |||
| rpc UserLogin(rpc.Request) returns(rpc.Response); | |||
| rpc UserRefreshToken(rpc.Request) returns(rpc.Response); | |||
| rpc UserLogout(rpc.Request) returns(rpc.Response); | |||
| rpc HubLoadAccessToken(rpc.Request) returns(rpc.Response); | |||
| } | |||
| @@ -25,6 +25,10 @@ const ( | |||
| Coordinator_GetHubConnectivities_FullMethodName = "/corrpc.Coordinator/GetHubConnectivities" | |||
| Coordinator_ReportHubConnectivity_FullMethodName = "/corrpc.Coordinator/ReportHubConnectivity" | |||
| Coordinator_SelectStorageHub_FullMethodName = "/corrpc.Coordinator/SelectStorageHub" | |||
| Coordinator_UserLogin_FullMethodName = "/corrpc.Coordinator/UserLogin" | |||
| Coordinator_UserRefreshToken_FullMethodName = "/corrpc.Coordinator/UserRefreshToken" | |||
| Coordinator_UserLogout_FullMethodName = "/corrpc.Coordinator/UserLogout" | |||
| Coordinator_HubLoadAccessToken_FullMethodName = "/corrpc.Coordinator/HubLoadAccessToken" | |||
| ) | |||
| // CoordinatorClient is the client API for Coordinator service. | |||
| @@ -36,6 +40,10 @@ type CoordinatorClient interface { | |||
| GetHubConnectivities(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| ReportHubConnectivity(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| SelectStorageHub(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| UserLogin(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| UserRefreshToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| UserLogout(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| HubLoadAccessToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| } | |||
| type coordinatorClient struct { | |||
| @@ -91,6 +99,42 @@ func (c *coordinatorClient) SelectStorageHub(ctx context.Context, in *rpc.Reques | |||
| return out, nil | |||
| } | |||
| func (c *coordinatorClient) UserLogin(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { | |||
| out := new(rpc.Response) | |||
| err := c.cc.Invoke(ctx, Coordinator_UserLogin_FullMethodName, in, out, opts...) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return out, nil | |||
| } | |||
| func (c *coordinatorClient) UserRefreshToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { | |||
| out := new(rpc.Response) | |||
| err := c.cc.Invoke(ctx, Coordinator_UserRefreshToken_FullMethodName, in, out, opts...) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return out, nil | |||
| } | |||
| func (c *coordinatorClient) UserLogout(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { | |||
| out := new(rpc.Response) | |||
| err := c.cc.Invoke(ctx, Coordinator_UserLogout_FullMethodName, in, out, opts...) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return out, nil | |||
| } | |||
| func (c *coordinatorClient) HubLoadAccessToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { | |||
| out := new(rpc.Response) | |||
| err := c.cc.Invoke(ctx, Coordinator_HubLoadAccessToken_FullMethodName, in, out, opts...) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return out, nil | |||
| } | |||
| // CoordinatorServer is the server API for Coordinator service. | |||
| // All implementations must embed UnimplementedCoordinatorServer | |||
| // for forward compatibility | |||
| @@ -100,6 +144,10 @@ type CoordinatorServer interface { | |||
| GetHubConnectivities(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| ReportHubConnectivity(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| SelectStorageHub(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| UserLogin(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| UserRefreshToken(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| UserLogout(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| HubLoadAccessToken(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| mustEmbedUnimplementedCoordinatorServer() | |||
| } | |||
| @@ -122,6 +170,18 @@ func (UnimplementedCoordinatorServer) ReportHubConnectivity(context.Context, *rp | |||
| func (UnimplementedCoordinatorServer) SelectStorageHub(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method SelectStorageHub not implemented") | |||
| } | |||
| func (UnimplementedCoordinatorServer) UserLogin(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method UserLogin not implemented") | |||
| } | |||
| func (UnimplementedCoordinatorServer) UserRefreshToken(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method UserRefreshToken not implemented") | |||
| } | |||
| func (UnimplementedCoordinatorServer) UserLogout(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method UserLogout not implemented") | |||
| } | |||
| func (UnimplementedCoordinatorServer) HubLoadAccessToken(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method HubLoadAccessToken not implemented") | |||
| } | |||
| func (UnimplementedCoordinatorServer) mustEmbedUnimplementedCoordinatorServer() {} | |||
| // UnsafeCoordinatorServer may be embedded to opt out of forward compatibility for this service. | |||
| @@ -225,6 +285,78 @@ func _Coordinator_SelectStorageHub_Handler(srv interface{}, ctx context.Context, | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| func _Coordinator_UserLogin_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | |||
| in := new(rpc.Request) | |||
| if err := dec(in); err != nil { | |||
| return nil, err | |||
| } | |||
| if interceptor == nil { | |||
| return srv.(CoordinatorServer).UserLogin(ctx, in) | |||
| } | |||
| info := &grpc.UnaryServerInfo{ | |||
| Server: srv, | |||
| FullMethod: Coordinator_UserLogin_FullMethodName, | |||
| } | |||
| handler := func(ctx context.Context, req interface{}) (interface{}, error) { | |||
| return srv.(CoordinatorServer).UserLogin(ctx, req.(*rpc.Request)) | |||
| } | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| func _Coordinator_UserRefreshToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | |||
| in := new(rpc.Request) | |||
| if err := dec(in); err != nil { | |||
| return nil, err | |||
| } | |||
| if interceptor == nil { | |||
| return srv.(CoordinatorServer).UserRefreshToken(ctx, in) | |||
| } | |||
| info := &grpc.UnaryServerInfo{ | |||
| Server: srv, | |||
| FullMethod: Coordinator_UserRefreshToken_FullMethodName, | |||
| } | |||
| handler := func(ctx context.Context, req interface{}) (interface{}, error) { | |||
| return srv.(CoordinatorServer).UserRefreshToken(ctx, req.(*rpc.Request)) | |||
| } | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| func _Coordinator_UserLogout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | |||
| in := new(rpc.Request) | |||
| if err := dec(in); err != nil { | |||
| return nil, err | |||
| } | |||
| if interceptor == nil { | |||
| return srv.(CoordinatorServer).UserLogout(ctx, in) | |||
| } | |||
| info := &grpc.UnaryServerInfo{ | |||
| Server: srv, | |||
| FullMethod: Coordinator_UserLogout_FullMethodName, | |||
| } | |||
| handler := func(ctx context.Context, req interface{}) (interface{}, error) { | |||
| return srv.(CoordinatorServer).UserLogout(ctx, req.(*rpc.Request)) | |||
| } | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| func _Coordinator_HubLoadAccessToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | |||
| in := new(rpc.Request) | |||
| if err := dec(in); err != nil { | |||
| return nil, err | |||
| } | |||
| if interceptor == nil { | |||
| return srv.(CoordinatorServer).HubLoadAccessToken(ctx, in) | |||
| } | |||
| info := &grpc.UnaryServerInfo{ | |||
| Server: srv, | |||
| FullMethod: Coordinator_HubLoadAccessToken_FullMethodName, | |||
| } | |||
| handler := func(ctx context.Context, req interface{}) (interface{}, error) { | |||
| return srv.(CoordinatorServer).HubLoadAccessToken(ctx, req.(*rpc.Request)) | |||
| } | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| // Coordinator_ServiceDesc is the grpc.ServiceDesc for Coordinator service. | |||
| // It's only intended for direct use with grpc.RegisterService, | |||
| // and not to be introspected or modified (even as a copy) | |||
| @@ -252,6 +384,22 @@ var Coordinator_ServiceDesc = grpc.ServiceDesc{ | |||
| MethodName: "SelectStorageHub", | |||
| Handler: _Coordinator_SelectStorageHub_Handler, | |||
| }, | |||
| { | |||
| MethodName: "UserLogin", | |||
| Handler: _Coordinator_UserLogin_Handler, | |||
| }, | |||
| { | |||
| MethodName: "UserRefreshToken", | |||
| Handler: _Coordinator_UserRefreshToken_Handler, | |||
| }, | |||
| { | |||
| MethodName: "UserLogout", | |||
| Handler: _Coordinator_UserLogout_Handler, | |||
| }, | |||
| { | |||
| MethodName: "HubLoadAccessToken", | |||
| Handler: _Coordinator_HubLoadAccessToken_Handler, | |||
| }, | |||
| }, | |||
| Streams: []grpc.StreamDesc{}, | |||
| Metadata: "pkgs/rpc/coordinator/coordinator.proto", | |||
| @@ -1,100 +1,115 @@ | |||
| package corrpc | |||
| import ( | |||
| "sync" | |||
| "time" | |||
| "crypto/tls" | |||
| "crypto/x509" | |||
| "fmt" | |||
| "os" | |||
| "gitlink.org.cn/cloudream/common/consts/errorcode" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| grpc "google.golang.org/grpc" | |||
| "google.golang.org/grpc/credentials/insecure" | |||
| "google.golang.org/grpc" | |||
| "google.golang.org/grpc/credentials" | |||
| ) | |||
| type PoolConfig struct { | |||
| Address string `json:"address"` | |||
| Address string | |||
| Conn rpc.PoolConfig | |||
| } | |||
| type Pool struct { | |||
| cfg PoolConfig | |||
| grpcCon *grpcCon | |||
| lock sync.Mutex | |||
| } | |||
| type grpcCon struct { | |||
| grpcCon *grpc.ClientConn | |||
| refCount int | |||
| stopClosing chan any | |||
| type PoolConfigJSON struct { | |||
| Address string `json:"address"` | |||
| RootCA string `json:"rootCA"` | |||
| ClientCert string `json:"clientCert"` | |||
| ClientKey string `json:"clientKey"` | |||
| } | |||
| func NewPool(cfg PoolConfig) *Pool { | |||
| return &Pool{ | |||
| cfg: cfg, | |||
| func (c *PoolConfigJSON) Build(tokenProv rpc.AccessTokenProvider) (*PoolConfig, error) { | |||
| pc := &PoolConfig{ | |||
| Address: c.Address, | |||
| } | |||
| } | |||
| pc.Conn.AccessTokenProvider = tokenProv | |||
| func (p *Pool) Get() *Client { | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| rootCA, err := os.ReadFile(c.RootCA) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("load root ca: %v", err) | |||
| } | |||
| pc.Conn.RootCA = x509.NewCertPool() | |||
| if !pc.Conn.RootCA.AppendCertsFromPEM(rootCA) { | |||
| return nil, fmt.Errorf("failed to parse root ca") | |||
| } | |||
| con := p.grpcCon | |||
| if con == nil { | |||
| gcon, err := grpc.NewClient(p.cfg.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) | |||
| if c.ClientCert != "" && c.ClientKey != "" { | |||
| cert, err := tls.LoadX509KeyPair(c.ClientCert, c.ClientKey) | |||
| if err != nil { | |||
| return &Client{ | |||
| con: nil, | |||
| pool: p, | |||
| fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), | |||
| } | |||
| } | |||
| con = &grpcCon{ | |||
| grpcCon: gcon, | |||
| refCount: 0, | |||
| stopClosing: nil, | |||
| return nil, fmt.Errorf("load client cert: %v", err) | |||
| } | |||
| p.grpcCon = con | |||
| } else if con.stopClosing != nil { | |||
| close(con.stopClosing) | |||
| con.stopClosing = nil | |||
| pc.Conn.ClientCert = &cert | |||
| } else if tokenProv == nil { | |||
| return nil, fmt.Errorf("must provide client cert or access token provider") | |||
| } | |||
| con.refCount++ | |||
| return pc, nil | |||
| } | |||
| return &Client{ | |||
| con: con.grpcCon, | |||
| cli: NewCoordinatorClient(con.grpcCon), | |||
| pool: p, | |||
| func (c *PoolConfigJSON) BuildTempClient() (*TempClient, error) { | |||
| rootCA, err := os.ReadFile(c.RootCA) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("load root ca: %v", err) | |||
| } | |||
| rootCAs := x509.NewCertPool() | |||
| if !rootCAs.AppendCertsFromPEM(rootCA) { | |||
| return nil, fmt.Errorf("failed to parse root ca") | |||
| } | |||
| } | |||
| func (p *Pool) release() { | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| gcon, err := grpc.NewClient(c.Address, | |||
| grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ | |||
| RootCAs: rootCAs, | |||
| ServerName: rpc.ClientAPISNIV1, | |||
| NextProtos: []string{"h2"}, | |||
| })), | |||
| ) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| grpcCon := p.grpcCon | |||
| grpcCon.refCount-- | |||
| grpcCon.refCount = max(grpcCon.refCount, 0) | |||
| return &TempClient{ | |||
| Client: Client{ | |||
| con: gcon, | |||
| cli: NewCoordinatorClient(gcon), | |||
| pool: nil, | |||
| fusedErr: nil, | |||
| }, | |||
| }, nil | |||
| } | |||
| if grpcCon.refCount == 0 { | |||
| stopClosing := make(chan any) | |||
| grpcCon.stopClosing = stopClosing | |||
| type Pool struct { | |||
| cfg PoolConfig | |||
| connPool *rpc.ConnPool | |||
| } | |||
| go func() { | |||
| select { | |||
| case <-stopClosing: | |||
| return | |||
| func NewPool(cfg PoolConfig) *Pool { | |||
| return &Pool{ | |||
| cfg: cfg, | |||
| connPool: rpc.NewConnPool(cfg.Conn), | |||
| } | |||
| } | |||
| case <-time.After(time.Minute): | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| func (p *Pool) Get() *Client { | |||
| con, err := p.connPool.GetConnection(p.cfg.Address) | |||
| if err != nil { | |||
| return &Client{ | |||
| con: nil, | |||
| cli: nil, | |||
| pool: p, | |||
| fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), | |||
| } | |||
| } | |||
| if p.grpcCon.refCount == 0 { | |||
| p.grpcCon.grpcCon.Close() | |||
| p.grpcCon = nil | |||
| } | |||
| } | |||
| }() | |||
| return &Client{ | |||
| con: con, | |||
| cli: NewCoordinatorClient(con), | |||
| pool: p, | |||
| fusedErr: nil, | |||
| } | |||
| } | |||
| @@ -7,6 +7,7 @@ import ( | |||
| type CoordinatorAPI interface { | |||
| HubService | |||
| StorageService | |||
| UserService | |||
| } | |||
| type Server struct { | |||
| @@ -15,12 +16,26 @@ type Server struct { | |||
| svrImpl CoordinatorAPI | |||
| } | |||
| func NewServer(cfg rpc.Config, impl CoordinatorAPI) *Server { | |||
| func NewServer(cfg rpc.Config, impl CoordinatorAPI, tokenVerifier rpc.AccessTokenVerifier) *Server { | |||
| svr := &Server{ | |||
| svrImpl: impl, | |||
| } | |||
| svr.ServerBase = rpc.NewServerBase(cfg, svr, &Coordinator_ServiceDesc) | |||
| svr.ServerBase = rpc.NewServerBase(cfg, svr, &Coordinator_ServiceDesc, tokenAuthAPIs, tokenVerifier, noAuthAPIs) | |||
| return svr | |||
| } | |||
| var _ CoordinatorServer = (*Server)(nil) | |||
| var tokenAuthAPIs []string | |||
| func TokenAuth(api string) bool { | |||
| tokenAuthAPIs = append(tokenAuthAPIs, api) | |||
| return true | |||
| } | |||
| var noAuthAPIs []string | |||
| func NoAuth(api string) bool { | |||
| noAuthAPIs = append(noAuthAPIs, api) | |||
| return true | |||
| } | |||
| @@ -19,6 +19,8 @@ type SelectStorageHubResp struct { | |||
| Hubs []*cortypes.Hub | |||
| } | |||
| var _ = TokenAuth(Coordinator_SelectStorageHub_FullMethodName) | |||
| func (c *Client) SelectStorageHub(ctx context.Context, msg *SelectStorageHub) (*SelectStorageHubResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| @@ -0,0 +1,95 @@ | |||
| package corrpc | |||
| import ( | |||
| context "context" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type UserService interface { | |||
| UserLogin(ctx context.Context, msg *UserLogin) (*UserLoginResp, *rpc.CodeError) | |||
| UserLogout(ctx context.Context, msg *UserLogout) (*UserLogoutResp, *rpc.CodeError) | |||
| UserRefreshToken(ctx context.Context, msg *UserRefreshToken) (*UserRefreshTokenResp, *rpc.CodeError) | |||
| HubLoadAccessToken(ctx context.Context, msg *HubLoadAccessToken) (*HubLoadAccessTokenResp, *rpc.CodeError) | |||
| } | |||
| // 客户端登录 | |||
| type UserLogin struct { | |||
| Account string | |||
| Password string | |||
| } | |||
| type UserLoginResp struct { | |||
| Token cortypes.UserAccessToken | |||
| PrivateKey string | |||
| } | |||
| var _ = NoAuth(Coordinator_UserLogin_FullMethodName) | |||
| func (c *Client) UserLogin(ctx context.Context, msg *UserLogin) (*UserLoginResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| } | |||
| return rpc.UnaryClient[*UserLoginResp](c.cli.UserLogin, ctx, msg) | |||
| } | |||
| func (s *Server) UserLogin(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { | |||
| return rpc.UnaryServer(s.svrImpl.UserLogin, ctx, req) | |||
| } | |||
| // 客户端刷新Token,原始Token会继续有效。 | |||
| type UserRefreshToken struct{} | |||
| type UserRefreshTokenResp struct { | |||
| Token cortypes.UserAccessToken | |||
| PrivateKey string | |||
| } | |||
| var _ = TokenAuth(Coordinator_UserLogin_FullMethodName) | |||
| func (c *Client) UserRefreshToken(ctx context.Context, msg *UserRefreshToken) (*UserRefreshTokenResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| } | |||
| return rpc.UnaryClient[*UserRefreshTokenResp](c.cli.UserRefreshToken, ctx, msg) | |||
| } | |||
| func (s *Server) UserRefreshToken(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { | |||
| return rpc.UnaryServer(s.svrImpl.UserRefreshToken, ctx, req) | |||
| } | |||
| // 客户端登出。会使用GRPC元数据中的TokenID和UserID来查找Token并删除。 | |||
| type UserLogout struct{} | |||
| type UserLogoutResp struct{} | |||
| var _ = TokenAuth(Coordinator_UserLogout_FullMethodName) | |||
| func (c *Client) UserLogout(ctx context.Context, msg *UserLogout) (*UserLogoutResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| } | |||
| return rpc.UnaryClient[*UserLogoutResp](c.cli.UserLogout, ctx, msg) | |||
| } | |||
| func (s *Server) UserLogout(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { | |||
| return rpc.UnaryServer(s.svrImpl.UserLogout, ctx, req) | |||
| } | |||
| // Hub服务加载AccessToken | |||
| type HubLoadAccessToken struct { | |||
| HubID cortypes.HubID | |||
| UserID cortypes.UserID | |||
| TokenID cortypes.AccessTokenID | |||
| } | |||
| type HubLoadAccessTokenResp struct { | |||
| Token cortypes.UserAccessToken | |||
| } | |||
| func (c *Client) HubLoadAccessToken(ctx context.Context, msg *HubLoadAccessToken) (*HubLoadAccessTokenResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| } | |||
| return rpc.UnaryClient[*HubLoadAccessTokenResp](c.cli.HubLoadAccessToken, ctx, msg) | |||
| } | |||
| func (s *Server) HubLoadAccessToken(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { | |||
| return rpc.UnaryServer(s.svrImpl.HubLoadAccessToken, ctx, req) | |||
| } | |||
| @@ -6,7 +6,7 @@ import ( | |||
| ) | |||
| type Client struct { | |||
| addr grpcAddr | |||
| addr string | |||
| con *grpc.ClientConn | |||
| cli HubClient | |||
| pool *Pool | |||
| @@ -15,7 +15,7 @@ type Client struct { | |||
| func (c *Client) Release() { | |||
| if c.con != nil { | |||
| c.pool.release(c.addr) | |||
| c.pool.connPool.Release(c.addr) | |||
| } | |||
| } | |||
| @@ -26,7 +26,7 @@ var file_pkgs_rpc_hub_hub_proto_rawDesc = []byte{ | |||
| 0x0a, 0x16, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x68, 0x75, 0x62, 0x2f, 0x68, | |||
| 0x75, 0x62, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, | |||
| 0x1a, 0x12, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x72, 0x70, 0x63, 0x2e, 0x70, | |||
| 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xb8, 0x02, 0x0a, 0x03, 0x48, 0x75, 0x62, 0x12, 0x2c, 0x0a, 0x0d, | |||
| 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xf5, 0x02, 0x0a, 0x03, 0x48, 0x75, 0x62, 0x12, 0x2c, 0x0a, 0x0d, | |||
| 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x49, 0x4f, 0x50, 0x6c, 0x61, 0x6e, 0x12, 0x0c, 0x2e, | |||
| 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, | |||
| 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x31, 0x0a, 0x0c, 0x53, 0x65, | |||
| @@ -45,12 +45,16 @@ var file_pkgs_rpc_hub_hub_proto_rawDesc = []byte{ | |||
| 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, | |||
| 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x53, 0x74, 0x61, | |||
| 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, | |||
| 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, | |||
| 0x40, 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, | |||
| 0x6e, 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, | |||
| 0x70, 0x75, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, | |||
| 0x72, 0x70, 0x63, 0x2f, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x3b, 0x68, 0x75, 0x62, 0x72, 0x70, | |||
| 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, | |||
| 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, | |||
| 0x3b, 0x0a, 0x1c, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x79, 0x55, 0x73, 0x65, 0x72, 0x41, 0x63, 0x63, | |||
| 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x49, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, | |||
| 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, | |||
| 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, | |||
| 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, 0x6e, 0x2f, 0x63, | |||
| 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, 0x70, 0x75, 0x62, | |||
| 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, | |||
| 0x2f, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x3b, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x62, 0x06, | |||
| 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, | |||
| } | |||
| var file_pkgs_rpc_hub_hub_proto_goTypes = []any{ | |||
| @@ -66,15 +70,17 @@ var file_pkgs_rpc_hub_hub_proto_depIdxs = []int32{ | |||
| 0, // 4: hubrpc.Hub.GetIOVar:input_type -> rpc.Request | |||
| 0, // 5: hubrpc.Hub.Ping:input_type -> rpc.Request | |||
| 0, // 6: hubrpc.Hub.GetState:input_type -> rpc.Request | |||
| 2, // 7: hubrpc.Hub.ExecuteIOPlan:output_type -> rpc.Response | |||
| 2, // 8: hubrpc.Hub.SendIOStream:output_type -> rpc.Response | |||
| 1, // 9: hubrpc.Hub.GetIOStream:output_type -> rpc.ChunkedData | |||
| 2, // 10: hubrpc.Hub.SendIOVar:output_type -> rpc.Response | |||
| 2, // 11: hubrpc.Hub.GetIOVar:output_type -> rpc.Response | |||
| 2, // 12: hubrpc.Hub.Ping:output_type -> rpc.Response | |||
| 2, // 13: hubrpc.Hub.GetState:output_type -> rpc.Response | |||
| 7, // [7:14] is the sub-list for method output_type | |||
| 0, // [0:7] is the sub-list for method input_type | |||
| 0, // 7: hubrpc.Hub.NotifyUserAccessTokenInvalid:input_type -> rpc.Request | |||
| 2, // 8: hubrpc.Hub.ExecuteIOPlan:output_type -> rpc.Response | |||
| 2, // 9: hubrpc.Hub.SendIOStream:output_type -> rpc.Response | |||
| 1, // 10: hubrpc.Hub.GetIOStream:output_type -> rpc.ChunkedData | |||
| 2, // 11: hubrpc.Hub.SendIOVar:output_type -> rpc.Response | |||
| 2, // 12: hubrpc.Hub.GetIOVar:output_type -> rpc.Response | |||
| 2, // 13: hubrpc.Hub.Ping:output_type -> rpc.Response | |||
| 2, // 14: hubrpc.Hub.GetState:output_type -> rpc.Response | |||
| 2, // 15: hubrpc.Hub.NotifyUserAccessTokenInvalid:output_type -> rpc.Response | |||
| 8, // [8:16] is the sub-list for method output_type | |||
| 0, // [0:8] is the sub-list for method input_type | |||
| 0, // [0:0] is the sub-list for extension type_name | |||
| 0, // [0:0] is the sub-list for extension extendee | |||
| 0, // [0:0] is the sub-list for field type_name | |||
| @@ -16,4 +16,6 @@ service Hub { | |||
| rpc Ping(rpc.Request) returns(rpc.Response); | |||
| rpc GetState(rpc.Request) returns(rpc.Response); | |||
| rpc NotifyUserAccessTokenInvalid(rpc.Request) returns(rpc.Response); | |||
| } | |||
| @@ -20,13 +20,14 @@ import ( | |||
| const _ = grpc.SupportPackageIsVersion7 | |||
| const ( | |||
| Hub_ExecuteIOPlan_FullMethodName = "/hubrpc.Hub/ExecuteIOPlan" | |||
| Hub_SendIOStream_FullMethodName = "/hubrpc.Hub/SendIOStream" | |||
| Hub_GetIOStream_FullMethodName = "/hubrpc.Hub/GetIOStream" | |||
| Hub_SendIOVar_FullMethodName = "/hubrpc.Hub/SendIOVar" | |||
| Hub_GetIOVar_FullMethodName = "/hubrpc.Hub/GetIOVar" | |||
| Hub_Ping_FullMethodName = "/hubrpc.Hub/Ping" | |||
| Hub_GetState_FullMethodName = "/hubrpc.Hub/GetState" | |||
| Hub_ExecuteIOPlan_FullMethodName = "/hubrpc.Hub/ExecuteIOPlan" | |||
| Hub_SendIOStream_FullMethodName = "/hubrpc.Hub/SendIOStream" | |||
| Hub_GetIOStream_FullMethodName = "/hubrpc.Hub/GetIOStream" | |||
| Hub_SendIOVar_FullMethodName = "/hubrpc.Hub/SendIOVar" | |||
| Hub_GetIOVar_FullMethodName = "/hubrpc.Hub/GetIOVar" | |||
| Hub_Ping_FullMethodName = "/hubrpc.Hub/Ping" | |||
| Hub_GetState_FullMethodName = "/hubrpc.Hub/GetState" | |||
| Hub_NotifyUserAccessTokenInvalid_FullMethodName = "/hubrpc.Hub/NotifyUserAccessTokenInvalid" | |||
| ) | |||
| // HubClient is the client API for Hub service. | |||
| @@ -40,6 +41,7 @@ type HubClient interface { | |||
| GetIOVar(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| Ping(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| GetState(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| NotifyUserAccessTokenInvalid(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) | |||
| } | |||
| type hubClient struct { | |||
| @@ -161,6 +163,15 @@ func (c *hubClient) GetState(ctx context.Context, in *rpc.Request, opts ...grpc. | |||
| return out, nil | |||
| } | |||
| func (c *hubClient) NotifyUserAccessTokenInvalid(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { | |||
| out := new(rpc.Response) | |||
| err := c.cc.Invoke(ctx, Hub_NotifyUserAccessTokenInvalid_FullMethodName, in, out, opts...) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return out, nil | |||
| } | |||
| // HubServer is the server API for Hub service. | |||
| // All implementations must embed UnimplementedHubServer | |||
| // for forward compatibility | |||
| @@ -172,6 +183,7 @@ type HubServer interface { | |||
| GetIOVar(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| Ping(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| GetState(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| NotifyUserAccessTokenInvalid(context.Context, *rpc.Request) (*rpc.Response, error) | |||
| mustEmbedUnimplementedHubServer() | |||
| } | |||
| @@ -200,6 +212,9 @@ func (UnimplementedHubServer) Ping(context.Context, *rpc.Request) (*rpc.Response | |||
| func (UnimplementedHubServer) GetState(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method GetState not implemented") | |||
| } | |||
| func (UnimplementedHubServer) NotifyUserAccessTokenInvalid(context.Context, *rpc.Request) (*rpc.Response, error) { | |||
| return nil, status.Errorf(codes.Unimplemented, "method NotifyUserAccessTokenInvalid not implemented") | |||
| } | |||
| func (UnimplementedHubServer) mustEmbedUnimplementedHubServer() {} | |||
| // UnsafeHubServer may be embedded to opt out of forward compatibility for this service. | |||
| @@ -350,6 +365,24 @@ func _Hub_GetState_Handler(srv interface{}, ctx context.Context, dec func(interf | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| func _Hub_NotifyUserAccessTokenInvalid_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | |||
| in := new(rpc.Request) | |||
| if err := dec(in); err != nil { | |||
| return nil, err | |||
| } | |||
| if interceptor == nil { | |||
| return srv.(HubServer).NotifyUserAccessTokenInvalid(ctx, in) | |||
| } | |||
| info := &grpc.UnaryServerInfo{ | |||
| Server: srv, | |||
| FullMethod: Hub_NotifyUserAccessTokenInvalid_FullMethodName, | |||
| } | |||
| handler := func(ctx context.Context, req interface{}) (interface{}, error) { | |||
| return srv.(HubServer).NotifyUserAccessTokenInvalid(ctx, req.(*rpc.Request)) | |||
| } | |||
| return interceptor(ctx, in, info, handler) | |||
| } | |||
| // Hub_ServiceDesc is the grpc.ServiceDesc for Hub service. | |||
| // It's only intended for direct use with grpc.RegisterService, | |||
| // and not to be introspected or modified (even as a copy) | |||
| @@ -377,6 +410,10 @@ var Hub_ServiceDesc = grpc.ServiceDesc{ | |||
| MethodName: "GetState", | |||
| Handler: _Hub_GetState_Handler, | |||
| }, | |||
| { | |||
| MethodName: "NotifyUserAccessTokenInvalid", | |||
| Handler: _Hub_NotifyUserAccessTokenInvalid_Handler, | |||
| }, | |||
| }, | |||
| Streams: []grpc.StreamDesc{ | |||
| { | |||
| @@ -22,6 +22,8 @@ type ExecuteIOPlan struct { | |||
| } | |||
| type ExecuteIOPlanResp struct{} | |||
| var _ = TokenAuth(Hub_ExecuteIOPlan_FullMethodName) | |||
| func (c *Client) ExecuteIOPlan(ctx context.Context, req *ExecuteIOPlan) (*ExecuteIOPlanResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| @@ -49,6 +51,8 @@ func (s *SendIOStream) SetStream(str io.Reader) { | |||
| type SendIOStreamResp struct{} | |||
| var _ = TokenAuth(Hub_SendIOStream_FullMethodName) | |||
| func (c *Client) SendIOStream(ctx context.Context, req *SendIOStream) (*SendIOStreamResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| @@ -71,6 +75,8 @@ type GetIOStreamResp struct { | |||
| Stream io.ReadCloser `json:"-"` | |||
| } | |||
| var _ = TokenAuth(Hub_GetIOStream_FullMethodName) | |||
| func (r *GetIOStreamResp) GetStream() io.ReadCloser { | |||
| return r.Stream | |||
| } | |||
| @@ -97,6 +103,8 @@ type SendIOVar struct { | |||
| } | |||
| type SendIOVarResp struct{} | |||
| var _ = TokenAuth(Hub_SendIOVar_FullMethodName) | |||
| func (c *Client) SendIOVar(ctx context.Context, req *SendIOVar) (*SendIOVarResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| @@ -119,6 +127,8 @@ type GetIOVarResp struct { | |||
| Value exec.VarValue | |||
| } | |||
| var _ = TokenAuth(Hub_GetIOVar_FullMethodName) | |||
| func (c *Client) GetIOVar(ctx context.Context, req *GetIOVar) (*GetIOVarResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| @@ -15,6 +15,8 @@ type MicsSvc interface { | |||
| type Ping struct{} | |||
| type PingResp struct{} | |||
| var _ = TokenAuth(Hub_Ping_FullMethodName) | |||
| func (c *Client) Ping(ctx context.Context, req *Ping) (*PingResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| @@ -1,114 +1,77 @@ | |||
| package hubrpc | |||
| import ( | |||
| "crypto/tls" | |||
| "crypto/x509" | |||
| "fmt" | |||
| "sync" | |||
| "time" | |||
| "os" | |||
| "gitlink.org.cn/cloudream/common/consts/errorcode" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| grpc "google.golang.org/grpc" | |||
| "google.golang.org/grpc/credentials/insecure" | |||
| ) | |||
| type PoolConfig struct{} | |||
| type PoolConfig struct { | |||
| Conn rpc.PoolConfig | |||
| } | |||
| type Pool struct { | |||
| grpcCons map[grpcAddr]*grpcCon | |||
| lock sync.Mutex | |||
| type PoolConfigJSON struct { | |||
| RootCA string `json:"rootCA"` | |||
| ClientCert string `json:"clientCert"` | |||
| ClientKey string `json:"clientKey"` | |||
| } | |||
| type grpcAddr struct { | |||
| IP string | |||
| Port int | |||
| func (c *PoolConfigJSON) Build(tokenProv rpc.AccessTokenProvider) (*PoolConfig, error) { | |||
| pc := &PoolConfig{} | |||
| pc.Conn.AccessTokenProvider = tokenProv | |||
| rootCA, err := os.ReadFile(c.RootCA) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("load root ca: %v", err) | |||
| } | |||
| pc.Conn.RootCA = x509.NewCertPool() | |||
| if !pc.Conn.RootCA.AppendCertsFromPEM(rootCA) { | |||
| return nil, fmt.Errorf("failed to parse root ca") | |||
| } | |||
| if c.ClientCert != "" && c.ClientKey != "" { | |||
| cert, err := tls.LoadX509KeyPair(c.ClientCert, c.ClientKey) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("load client cert: %v", err) | |||
| } | |||
| pc.Conn.ClientCert = &cert | |||
| } else if tokenProv == nil { | |||
| return nil, fmt.Errorf("must provide client cert or access token provider") | |||
| } | |||
| return pc, nil | |||
| } | |||
| type grpcCon struct { | |||
| grpcCon *grpc.ClientConn | |||
| refCount int | |||
| stopClosing chan any | |||
| type Pool struct { | |||
| connPool *rpc.ConnPool | |||
| } | |||
| func NewPool(cfg PoolConfig) *Pool { | |||
| return &Pool{ | |||
| grpcCons: make(map[grpcAddr]*grpcCon), | |||
| connPool: rpc.NewConnPool(cfg.Conn), | |||
| } | |||
| } | |||
| func (p *Pool) Get(ip string, port int) *Client { | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| ga := grpcAddr{IP: ip, Port: port} | |||
| con := p.grpcCons[ga] | |||
| if con == nil { | |||
| gcon, err := grpc.NewClient(fmt.Sprintf("%v:%v", ip, port), grpc.WithTransportCredentials(insecure.NewCredentials())) | |||
| if err != nil { | |||
| return &Client{ | |||
| addr: ga, | |||
| con: nil, | |||
| pool: p, | |||
| fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), | |||
| } | |||
| } | |||
| con = &grpcCon{ | |||
| grpcCon: gcon, | |||
| refCount: 0, | |||
| stopClosing: nil, | |||
| addr := fmt.Sprintf("%s:%d", ip, port) | |||
| con, err := p.connPool.GetConnection(addr) | |||
| if err != nil { | |||
| return &Client{ | |||
| addr: addr, | |||
| con: nil, | |||
| pool: p, | |||
| fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), | |||
| } | |||
| p.grpcCons[ga] = con | |||
| } else if con.stopClosing != nil { | |||
| close(con.stopClosing) | |||
| con.stopClosing = nil | |||
| } | |||
| con.refCount++ | |||
| return &Client{ | |||
| addr: ga, | |||
| con: con.grpcCon, | |||
| cli: NewHubClient(con.grpcCon), | |||
| addr: addr, | |||
| con: con, | |||
| cli: NewHubClient(con), | |||
| pool: p, | |||
| } | |||
| } | |||
| func (p *Pool) release(addr grpcAddr) { | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| grpcCon := p.grpcCons[addr] | |||
| if grpcCon == nil { | |||
| return | |||
| } | |||
| grpcCon.refCount-- | |||
| grpcCon.refCount = max(grpcCon.refCount, 0) | |||
| if grpcCon.refCount == 0 { | |||
| stopClosing := make(chan any) | |||
| grpcCon.stopClosing = stopClosing | |||
| go func() { | |||
| select { | |||
| case <-stopClosing: | |||
| return | |||
| case <-time.After(time.Minute): | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| grpcCon := p.grpcCons[addr] | |||
| if grpcCon == nil { | |||
| return | |||
| } | |||
| if grpcCon.refCount == 0 { | |||
| grpcCon.grpcCon.Close() | |||
| delete(p.grpcCons, addr) | |||
| } | |||
| } | |||
| }() | |||
| } | |||
| } | |||
| @@ -8,7 +8,7 @@ type HubAPI interface { | |||
| // CacheSvc | |||
| IOSwitchSvc | |||
| MicsSvc | |||
| // UserSpaceSvc | |||
| UserSvc | |||
| } | |||
| type Server struct { | |||
| @@ -17,12 +17,26 @@ type Server struct { | |||
| svrImpl HubAPI | |||
| } | |||
| func NewServer(cfg rpc.Config, impl HubAPI) *Server { | |||
| func NewServer(cfg rpc.Config, impl HubAPI, tokenVerifier rpc.AccessTokenVerifier) *Server { | |||
| svr := &Server{ | |||
| svrImpl: impl, | |||
| } | |||
| svr.ServerBase = rpc.NewServerBase(cfg, svr, &Hub_ServiceDesc) | |||
| svr.ServerBase = rpc.NewServerBase(cfg, svr, &Hub_ServiceDesc, tokenAuthAPIs, tokenVerifier, noAuthAPIs) | |||
| return svr | |||
| } | |||
| var _ HubServer = (*Server)(nil) | |||
| var tokenAuthAPIs []string | |||
| func TokenAuth(api string) bool { | |||
| tokenAuthAPIs = append(tokenAuthAPIs, api) | |||
| return true | |||
| } | |||
| var noAuthAPIs []string | |||
| func NoAuth(api string) bool { | |||
| noAuthAPIs = append(noAuthAPIs, api) | |||
| return true | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| package hubrpc | |||
| import ( | |||
| context "context" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type UserSvc interface { | |||
| NotifyUserAccessTokenInvalid(ctx context.Context, req *NotifyUserAccessTokenInvalid) (*NotifyUserAccessTokenInvalidResp, *rpc.CodeError) | |||
| } | |||
| // 通知用户的Token登出 | |||
| type NotifyUserAccessTokenInvalid struct { | |||
| UserID cortypes.UserID | |||
| TokenID cortypes.AccessTokenID | |||
| } | |||
| type NotifyUserAccessTokenInvalidResp struct{} | |||
| func (c *Client) NotifyUserAccessTokenInvalid(ctx context.Context, req *NotifyUserAccessTokenInvalid) (*NotifyUserAccessTokenInvalidResp, *rpc.CodeError) { | |||
| if c.fusedErr != nil { | |||
| return nil, c.fusedErr | |||
| } | |||
| return rpc.UnaryClient[*NotifyUserAccessTokenInvalidResp](c.cli.NotifyUserAccessTokenInvalid, ctx, req) | |||
| } | |||
| func (s *Server) NotifyUserAccessTokenInvalid(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { | |||
| return rpc.UnaryServer(s.svrImpl.NotifyUserAccessTokenInvalid, ctx, req) | |||
| } | |||
| @@ -0,0 +1,176 @@ | |||
| package rpc | |||
| import ( | |||
| context "context" | |||
| "crypto/tls" | |||
| "crypto/x509" | |||
| "fmt" | |||
| "sync" | |||
| "time" | |||
| grpc "google.golang.org/grpc" | |||
| "google.golang.org/grpc/credentials" | |||
| "google.golang.org/grpc/metadata" | |||
| ) | |||
| type PoolConfig struct { | |||
| RootCA *x509.CertPool | |||
| // 客户端证书,与AccessTokenProvider二选一 | |||
| ClientCert *tls.Certificate | |||
| // AccessTokenProvider,与ClientCert二选一 | |||
| AccessTokenProvider AccessTokenProvider | |||
| } | |||
| type ConnPool struct { | |||
| cfg PoolConfig | |||
| grpcCons map[string]*grpcCon | |||
| lock sync.Mutex | |||
| } | |||
| type grpcCon struct { | |||
| grpcCon *grpc.ClientConn | |||
| refCount int | |||
| stopClosing chan any | |||
| } | |||
| func NewConnPool(cfg PoolConfig) *ConnPool { | |||
| return &ConnPool{ | |||
| cfg: cfg, | |||
| grpcCons: make(map[string]*grpcCon), | |||
| } | |||
| } | |||
| func (p *ConnPool) GetConnection(addr string) (*grpc.ClientConn, error) { | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| con := p.grpcCons[addr] | |||
| if con == nil { | |||
| gcon, err := p.connecting(addr) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| con = &grpcCon{ | |||
| grpcCon: gcon, | |||
| refCount: 0, | |||
| stopClosing: nil, | |||
| } | |||
| p.grpcCons[addr] = con | |||
| } else if con.stopClosing != nil { | |||
| close(con.stopClosing) | |||
| con.stopClosing = nil | |||
| } | |||
| con.refCount++ | |||
| return con.grpcCon, nil | |||
| } | |||
| func (p *ConnPool) connecting(addr string) (*grpc.ClientConn, error) { | |||
| if p.cfg.ClientCert != nil { | |||
| gcon, err := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ | |||
| RootCAs: p.cfg.RootCA, | |||
| Certificates: []tls.Certificate{*p.cfg.ClientCert}, | |||
| ServerName: InternalAPISNIV1, | |||
| NextProtos: []string{"h2"}, | |||
| }))) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return gcon, nil | |||
| } | |||
| if p.cfg.AccessTokenProvider == nil { | |||
| return nil, fmt.Errorf("no client cert or access token provider") | |||
| } | |||
| gcon, err := grpc.NewClient(addr, | |||
| grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ | |||
| RootCAs: p.cfg.RootCA, | |||
| ServerName: ClientAPISNIV1, | |||
| NextProtos: []string{"h2"}, | |||
| })), | |||
| grpc.WithUnaryInterceptor(p.populateAccessTokenUnary), | |||
| grpc.WithStreamInterceptor(p.populateAccessTokenStream), | |||
| ) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return gcon, nil | |||
| } | |||
| func (p *ConnPool) Release(addr string) { | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| grpcCon := p.grpcCons[addr] | |||
| if grpcCon == nil { | |||
| return | |||
| } | |||
| grpcCon.refCount-- | |||
| grpcCon.refCount = max(grpcCon.refCount, 0) | |||
| if grpcCon.refCount == 0 { | |||
| stopClosing := make(chan any) | |||
| grpcCon.stopClosing = stopClosing | |||
| go func() { | |||
| select { | |||
| case <-stopClosing: | |||
| return | |||
| case <-time.After(time.Minute): | |||
| p.lock.Lock() | |||
| defer p.lock.Unlock() | |||
| grpcCon := p.grpcCons[addr] | |||
| if grpcCon == nil { | |||
| return | |||
| } | |||
| if grpcCon.refCount == 0 { | |||
| grpcCon.grpcCon.Close() | |||
| delete(p.grpcCons, addr) | |||
| } | |||
| } | |||
| }() | |||
| } | |||
| } | |||
| func (p *ConnPool) populateAccessTokenUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { | |||
| authInfo, err := p.cfg.AccessTokenProvider.GetAuthInfo() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| md := metadata.Pairs( | |||
| MetaUserID, fmt.Sprintf("%v", authInfo.UserID), | |||
| MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID), | |||
| MetaNonce, authInfo.Nonce, | |||
| MetaSignature, authInfo.Signature, | |||
| ) | |||
| ctx = metadata.NewOutgoingContext(ctx, md) | |||
| return invoker(ctx, method, req, reply, cc, opts...) | |||
| } | |||
| func (p *ConnPool) populateAccessTokenStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { | |||
| authInfo, err := p.cfg.AccessTokenProvider.GetAuthInfo() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| md := metadata.Pairs( | |||
| MetaUserID, fmt.Sprintf("%v", authInfo.UserID), | |||
| MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID), | |||
| MetaNonce, authInfo.Nonce, | |||
| MetaSignature, authInfo.Signature, | |||
| ) | |||
| ctx = metadata.NewOutgoingContext(ctx, md) | |||
| return streamer(ctx, desc, cc, method, opts...) | |||
| } | |||
| @@ -0,0 +1 @@ | |||
| package rpc | |||
| @@ -1,11 +1,16 @@ | |||
| package rpc | |||
| import ( | |||
| "crypto/tls" | |||
| "crypto/x509" | |||
| "fmt" | |||
| "net" | |||
| "os" | |||
| "gitlink.org.cn/cloudream/common/pkgs/async" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "google.golang.org/grpc" | |||
| "google.golang.org/grpc/credentials" | |||
| ) | |||
| type ServerEventChan = async.UnboundChannel[RPCServerEvent] | |||
| @@ -20,34 +25,84 @@ type ExitEvent struct { | |||
| } | |||
| type Config struct { | |||
| Listen string `json:"listen"` | |||
| Listen string `json:"listen"` | |||
| RootCA string `json:"rootCA"` | |||
| ServerCert string `json:"serverCert"` | |||
| ServerKey string `json:"serverKey"` | |||
| } | |||
| type ServerBase struct { | |||
| cfg Config | |||
| grpcSvr *grpc.Server | |||
| srvImpl any | |||
| svcDesc *grpc.ServiceDesc | |||
| cfg Config | |||
| grpcSvr *grpc.Server | |||
| srvImpl any | |||
| svcDesc *grpc.ServiceDesc | |||
| rootCA *x509.CertPool | |||
| serverCert tls.Certificate | |||
| accessTokenAuthAPIs map[string]bool | |||
| tokenVerifier AccessTokenVerifier | |||
| noAuthAPIs map[string]bool | |||
| } | |||
| func NewServerBase(cfg Config, srvImpl any, svcDesc *grpc.ServiceDesc) *ServerBase { | |||
| func NewServerBase(cfg Config, srvImpl any, svcDesc *grpc.ServiceDesc, accessTokenAuthAPIs []string, tokenVerifier AccessTokenVerifier, noAuthAPIs []string) *ServerBase { | |||
| tokenAuthAPIs := make(map[string]bool) | |||
| for _, api := range accessTokenAuthAPIs { | |||
| tokenAuthAPIs[api] = true | |||
| } | |||
| noAuths := make(map[string]bool) | |||
| for _, api := range noAuthAPIs { | |||
| noAuths[api] = true | |||
| } | |||
| return &ServerBase{ | |||
| cfg: cfg, | |||
| srvImpl: srvImpl, | |||
| svcDesc: svcDesc, | |||
| cfg: cfg, | |||
| srvImpl: srvImpl, | |||
| svcDesc: svcDesc, | |||
| accessTokenAuthAPIs: tokenAuthAPIs, | |||
| tokenVerifier: tokenVerifier, | |||
| noAuthAPIs: noAuths, | |||
| } | |||
| } | |||
| func (s *ServerBase) Start() *ServerEventChan { | |||
| ch := async.NewUnboundChannel[RPCServerEvent]() | |||
| go func() { | |||
| svrCert, err := tls.LoadX509KeyPair(s.cfg.ServerCert, s.cfg.ServerKey) | |||
| if err != nil { | |||
| logger.Warnf("load server cert: %v", err) | |||
| ch.Send(&ExitEvent{Err: err}) | |||
| return | |||
| } | |||
| s.serverCert = svrCert | |||
| rootCA, err := os.ReadFile(s.cfg.RootCA) | |||
| if err != nil { | |||
| logger.Warnf("load root ca: %v", err) | |||
| ch.Send(&ExitEvent{Err: err}) | |||
| return | |||
| } | |||
| s.rootCA = x509.NewCertPool() | |||
| if !s.rootCA.AppendCertsFromPEM(rootCA) { | |||
| logger.Warnf("load root ca: failed to parse root ca") | |||
| ch.Send(&ExitEvent{Err: fmt.Errorf("failed to parse root ca")}) | |||
| return | |||
| } | |||
| logger.Infof("start serving rpc at: %v", s.cfg.Listen) | |||
| lis, err := net.Listen("tcp", s.cfg.Listen) | |||
| if err != nil { | |||
| ch.Send(&ExitEvent{Err: err}) | |||
| return | |||
| } | |||
| s.grpcSvr = grpc.NewServer() | |||
| s.grpcSvr = grpc.NewServer( | |||
| grpc.Creds(credentials.NewTLS(&tls.Config{ | |||
| GetConfigForClient: s.tlsConfigSelector, | |||
| })), | |||
| grpc.UnaryInterceptor(s.authUnary), | |||
| grpc.StreamInterceptor(s.authStream), | |||
| ) | |||
| s.grpcSvr.RegisterService(s.svcDesc, s.srvImpl) | |||
| err = s.grpcSvr.Serve(lis) | |||
| ch.Send(&ExitEvent{Err: err}) | |||
| @@ -38,17 +38,17 @@ func UnaryClient[Resp, Req any](apiFn func(context.Context, *Request, ...grpc.Ca | |||
| func UnaryServer[Resp, Req any](apiFn func(context.Context, Req) (Resp, *CodeError), ctx context.Context, req *Request) (*Response, error) { | |||
| rreq, err := serder.JSONToObjectEx[Req](req.Payload) | |||
| if err != nil { | |||
| return nil, makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return nil, MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| ret, cerr := apiFn(ctx, rreq) | |||
| if cerr != nil { | |||
| return nil, wrapCodeError(cerr) | |||
| return nil, WrapCodeError(cerr) | |||
| } | |||
| data, err := serder.ObjectToJSONEx(ret) | |||
| if err != nil { | |||
| return nil, makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return nil, MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| return &Response{ | |||
| @@ -120,33 +120,33 @@ func UploadStreamServer[Resp any, Req UploadStreamReq, APIRet UploadStreamAPISer | |||
| cr := NewChunkedReader(req) | |||
| _, data, err := cr.NextDataPart() | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| _, pr, err := cr.NextPart() | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| rreq, err := serder.JSONToObjectEx[Req](data) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| rreq.SetStream(pr) | |||
| resp, cerr := apiFn(req.Context(), rreq) | |||
| if cerr != nil { | |||
| return wrapCodeError(cerr) | |||
| return WrapCodeError(cerr) | |||
| } | |||
| respData, err := serder.ObjectToJSONEx(resp) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| err = req.SendAndClose(&Response{Payload: respData}) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| return nil | |||
| @@ -211,33 +211,33 @@ func DownloadStreamClient[Resp DownloadStreamResp, Req any, APIRet DownloadStrea | |||
| func DownloadStreamServer[Resp DownloadStreamResp, Req any, APIRet DownloadStreamAPIServer](apiFn func(context.Context, Req) (Resp, *CodeError), req *Request, ret APIRet) error { | |||
| rreq, err := serder.JSONToObjectEx[Req](req.Payload) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| resp, cerr := apiFn(ret.Context(), rreq) | |||
| if cerr != nil { | |||
| return wrapCodeError(cerr) | |||
| return WrapCodeError(cerr) | |||
| } | |||
| cw := NewChunkedWriter(ret) | |||
| data, err := serder.ObjectToJSONEx(resp) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| err = cw.WriteDataPart("", data) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| _, err = cw.WriteStreamPart("", resp.GetStream()) | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| err = cw.Finish() | |||
| if err != nil { | |||
| return makeCodeError(errorcode.OperationFailed, err.Error()) | |||
| return MakeCodeError(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| return nil | |||
| @@ -282,12 +282,12 @@ func getCodeError(err error) *CodeError { | |||
| return Failed(errorcode.OperationFailed, err.Error()) | |||
| } | |||
| func makeCodeError(code string, msg string) error { | |||
| func MakeCodeError(code string, msg string) error { | |||
| ce, _ := status.New(codes.Unknown, "custom error").WithDetails(Failed(code, msg)) | |||
| return ce.Err() | |||
| } | |||
| func wrapCodeError(ce *CodeError) error { | |||
| func WrapCodeError(ce *CodeError) error { | |||
| e, _ := status.New(codes.Unknown, "custom error").WithDetails(ce) | |||
| return e.Err() | |||
| } | |||
| @@ -0,0 +1,38 @@ | |||
| package accesstoken | |||
| import ( | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| "gorm.io/gorm" | |||
| ) | |||
| type ExitEvent = accesstoken.ExitEvent | |||
| type CacheKey = accesstoken.CacheKey | |||
| type Cache struct { | |||
| *accesstoken.Cache | |||
| db *db.DB | |||
| } | |||
| func New(db *db.DB) *Cache { | |||
| c := &Cache{ | |||
| db: db, | |||
| } | |||
| c.Cache = accesstoken.New(c.load) | |||
| return c | |||
| } | |||
| func (c *Cache) load(key accesstoken.CacheKey) (cortypes.UserAccessToken, error) { | |||
| token, err := c.db.UserAccessToken().GetByID(c.db.DefCtx(), key.UserID, key.TokenID) | |||
| if err == gorm.ErrRecordNotFound { | |||
| return cortypes.UserAccessToken{}, accesstoken.ErrTokenNotFound | |||
| } | |||
| if err != nil { | |||
| return cortypes.UserAccessToken{}, err | |||
| } | |||
| return token, nil | |||
| } | |||
| @@ -0,0 +1,212 @@ | |||
| package cmd | |||
| import ( | |||
| "crypto/rand" | |||
| "crypto/rsa" | |||
| "crypto/x509" | |||
| "crypto/x509/pkix" | |||
| "encoding/pem" | |||
| "fmt" | |||
| "math/big" | |||
| "os" | |||
| "path/filepath" | |||
| "time" | |||
| "github.com/spf13/cobra" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| ) | |||
| func init() { | |||
| certCmd := cobra.Command{ | |||
| Use: "cert", | |||
| } | |||
| RootCmd.AddCommand(&certCmd) | |||
| certRoot := cobra.Command{ | |||
| Use: "root [outputDir]", | |||
| Args: cobra.ExactArgs(1), | |||
| Run: func(cmd *cobra.Command, args []string) { | |||
| certRoot(args[0]) | |||
| }, | |||
| } | |||
| certCmd.AddCommand(&certRoot) | |||
| var certFilePath string | |||
| var keyFilePath string | |||
| certServer := cobra.Command{ | |||
| Use: "server [outputDir]", | |||
| Args: cobra.ExactArgs(1), | |||
| Run: func(cmd *cobra.Command, args []string) { | |||
| certServer(certFilePath, keyFilePath, args[0]) | |||
| }, | |||
| } | |||
| certServer.Flags().StringVar(&certFilePath, "cert", "", "CA certificate file path") | |||
| certServer.Flags().StringVar(&keyFilePath, "key", "", "CA key file path") | |||
| certCmd.AddCommand(&certServer) | |||
| certClient := cobra.Command{ | |||
| Use: "client [outputDir]", | |||
| Args: cobra.ExactArgs(1), | |||
| Run: func(cmd *cobra.Command, args []string) { | |||
| certClient(certFilePath, keyFilePath, args[0]) | |||
| }, | |||
| } | |||
| certClient.Flags().StringVar(&certFilePath, "cert", "", "CA certificate file path") | |||
| certClient.Flags().StringVar(&keyFilePath, "key", "", "CA key file path") | |||
| certCmd.AddCommand(&certClient) | |||
| } | |||
| func certRoot(output string) { | |||
| caPriv, _ := rsa.GenerateKey(rand.Reader, 2048) | |||
| // 创建 CA 证书模板 | |||
| caTemplate := &x509.Certificate{ | |||
| SerialNumber: big.NewInt(1), | |||
| Subject: pkix.Name{ | |||
| Organization: []string{"JCS"}, | |||
| }, | |||
| NotBefore: time.Now(), | |||
| NotAfter: time.Now().AddDate(10, 0, 0), // 有效期10年 | |||
| KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, | |||
| BasicConstraintsValid: true, | |||
| IsCA: true, | |||
| } | |||
| // 自签名 CA 证书 | |||
| caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caPriv.PublicKey, caPriv) | |||
| // 保存 CA 证书和私钥 | |||
| writePem(filepath.Join(output, "ca_cert.pem"), "CERTIFICATE", caCertDER) | |||
| writePem(filepath.Join(output, "ca_key.pem"), "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(caPriv)) | |||
| fmt.Println("CA certificate and key saved to", output) | |||
| } | |||
| func certServer(certFile string, keyFile string, output string) { | |||
| // 读取 CA 证书和私钥 | |||
| caCertPEM, err := os.ReadFile(certFile) | |||
| if err != nil { | |||
| fmt.Println("Failed to read CA certificate:", err) | |||
| return | |||
| } | |||
| caKeyPEM, err := os.ReadFile(keyFile) | |||
| if err != nil { | |||
| fmt.Println("Failed to read CA key:", err) | |||
| return | |||
| } | |||
| caCertPEMBlock, _ := pem.Decode(caCertPEM) | |||
| if caCertPEMBlock == nil { | |||
| fmt.Println("Failed to decode CA certificate") | |||
| return | |||
| } | |||
| caKeyPEMBlock, _ := pem.Decode(caKeyPEM) | |||
| if caKeyPEMBlock == nil { | |||
| fmt.Println("Failed to decode CA key") | |||
| return | |||
| } | |||
| caCert, err := x509.ParseCertificate(caCertPEMBlock.Bytes) | |||
| if err != nil { | |||
| fmt.Println("Failed to parse CA certificate:", err) | |||
| return | |||
| } | |||
| caKey, err := x509.ParsePKCS1PrivateKey(caKeyPEMBlock.Bytes) | |||
| if err != nil { | |||
| fmt.Println("Failed to parse CA key:", err) | |||
| return | |||
| } | |||
| // 生成服务端私钥 | |||
| serverPriv, _ := rsa.GenerateKey(rand.Reader, 2048) | |||
| // 服务端证书模板 | |||
| serverTemplate := &x509.Certificate{ | |||
| SerialNumber: big.NewInt(2), | |||
| Subject: pkix.Name{ | |||
| CommonName: "localhost", | |||
| }, | |||
| NotBefore: time.Now(), | |||
| NotAfter: time.Now().AddDate(1, 0, 0), // 有效期1年 | |||
| KeyUsage: x509.KeyUsageDigitalSignature, | |||
| ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, | |||
| BasicConstraintsValid: true, | |||
| } | |||
| // 添加主机名/IP 到证书 | |||
| serverTemplate.DNSNames = []string{rpc.ClientAPISNIV1, rpc.InternalAPISNIV1} | |||
| // 用 CA 签发服务端证书 | |||
| serverCertDER, _ := x509.CreateCertificate(rand.Reader, serverTemplate, caCert, &serverPriv.PublicKey, caKey) | |||
| // 保存服务端证书和私钥 | |||
| writePem(filepath.Join(output, "server_cert.pem"), "CERTIFICATE", serverCertDER) | |||
| writePem(filepath.Join(output, "server_key.pem"), "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(serverPriv)) | |||
| fmt.Println("Server certificate and key saved to", output) | |||
| } | |||
| func certClient(certFile string, keyFile string, output string) { | |||
| // 读取 CA 证书和私钥 | |||
| caCertPEM, err := os.ReadFile(certFile) | |||
| if err != nil { | |||
| fmt.Println("Failed to read CA certificate:", err) | |||
| return | |||
| } | |||
| caKeyPEM, err := os.ReadFile(keyFile) | |||
| if err != nil { | |||
| fmt.Println("Failed to read CA key:", err) | |||
| return | |||
| } | |||
| caCertPEMBlock, _ := pem.Decode(caCertPEM) | |||
| if caCertPEMBlock == nil { | |||
| fmt.Println("Failed to decode CA certificate") | |||
| return | |||
| } | |||
| caKeyPEMBlock, _ := pem.Decode(caKeyPEM) | |||
| if caKeyPEMBlock == nil { | |||
| fmt.Println("Failed to decode CA key") | |||
| return | |||
| } | |||
| caCert, err := x509.ParseCertificate(caCertPEMBlock.Bytes) | |||
| if err != nil { | |||
| fmt.Println("Failed to parse CA certificate:", err) | |||
| return | |||
| } | |||
| caKey, err := x509.ParsePKCS1PrivateKey(caKeyPEMBlock.Bytes) | |||
| if err != nil { | |||
| fmt.Println("Failed to parse CA key:", err) | |||
| return | |||
| } | |||
| // 生成客户端私钥 | |||
| clientPriv, _ := rsa.GenerateKey(rand.Reader, 2048) | |||
| // 客户端证书模板 | |||
| clientTemplate := &x509.Certificate{ | |||
| SerialNumber: big.NewInt(3), | |||
| Subject: pkix.Name{ | |||
| CommonName: "client", | |||
| }, | |||
| NotBefore: time.Now(), | |||
| NotAfter: time.Now().AddDate(1, 0, 0), // 有效期1年 | |||
| KeyUsage: x509.KeyUsageDigitalSignature, | |||
| ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, | |||
| BasicConstraintsValid: true, | |||
| } | |||
| // 用 CA 签发客户端证书 | |||
| clientCertDER, _ := x509.CreateCertificate(rand.Reader, clientTemplate, caCert, &clientPriv.PublicKey, caKey) | |||
| // 保存客户端证书和私钥 | |||
| writePem(filepath.Join(output, "client_cert.pem"), "CERTIFICATE", clientCertDER) | |||
| writePem(filepath.Join(output, "client_key.pem"), "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(clientPriv)) | |||
| fmt.Println("Client certificate and key saved to", output) | |||
| } | |||
| func writePem(filename, pemType string, bytes []byte) { | |||
| f, _ := os.Create(filename) | |||
| pem.Encode(f, &pem.Block{Type: pemType, Bytes: bytes}) | |||
| f.Close() | |||
| } | |||
| @@ -42,6 +42,8 @@ func migrate(configPath string) { | |||
| migrateOne(db, cortypes.Hub{}) | |||
| migrateOne(db, cortypes.HubLocation{}) | |||
| migrateOne(db, cortypes.User{}) | |||
| migrateOne(db, cortypes.UserAccessToken{}) | |||
| migrateOne(db, cortypes.LoadedAccessToken{}) | |||
| fmt.Println("migrate success") | |||
| } | |||
| @@ -9,7 +9,7 @@ import ( | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/repl" | |||
| @@ -44,7 +44,13 @@ func serve(configPath string) { | |||
| os.Exit(1) | |||
| } | |||
| stgglb.InitPools(&hubrpc.PoolConfig{}, nil) | |||
| hubRPCCfg, err := config.Cfg().HubRPC.Build(nil) | |||
| if err != nil { | |||
| logger.Errorf("build hub rpc config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| stgglb.InitPools(hubRPCCfg, nil) | |||
| db2, err := db.NewDB(&config.Cfg().DB) | |||
| if err != nil { | |||
| @@ -59,13 +65,18 @@ func serve(configPath string) { | |||
| // } | |||
| // go servePublisher(evtPub) | |||
| // 客户端访问令牌缓存 | |||
| accToken := accesstoken.New(db2) | |||
| accTokenChan := accToken.Start() | |||
| defer accToken.Stop() | |||
| // RPC服务 | |||
| rpcSvr := corrpc.NewServer(config.Cfg().RPC, myrpc.NewService(db2)) | |||
| rpcSvr := corrpc.NewServer(config.Cfg().RPC, myrpc.NewService(db2, accToken), accToken) | |||
| rpcSvrChan := rpcSvr.Start() | |||
| defer rpcSvr.Stop() | |||
| // 定时任务 | |||
| tktk := ticktock.New(config.Cfg().TickTock, db2) | |||
| tktk := ticktock.New(config.Cfg().TickTock, db2, accToken) | |||
| tktk.Start() | |||
| defer tktk.Stop() | |||
| @@ -74,11 +85,29 @@ func serve(configPath string) { | |||
| replCh := rep.Start() | |||
| /// 开始监听各个模块的事件 | |||
| accTokenEvt := accTokenChan.Receive() | |||
| replEvt := replCh.Receive() | |||
| rpcEvt := rpcSvrChan.Receive() | |||
| loop: | |||
| for { | |||
| select { | |||
| case e := <-accTokenEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive access token event: %v", e.Err) | |||
| break loop | |||
| } | |||
| switch e := e.Value.(type) { | |||
| case accesstoken.ExitEvent: | |||
| if e.Err != nil { | |||
| logger.Errorf("access token cache exited with error: %v", e.Err) | |||
| } else { | |||
| logger.Info("access token cache exited") | |||
| } | |||
| break loop | |||
| } | |||
| accTokenEvt = accTokenChan.Receive() | |||
| case e := <-replEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive repl event: %v", err) | |||
| @@ -4,15 +4,17 @@ import ( | |||
| log "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| c "gitlink.org.cn/cloudream/common/utils/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/ticktock" | |||
| ) | |||
| type Config struct { | |||
| Logger log.Config `json:"logger"` | |||
| DB db.Config `json:"db"` | |||
| TickTock ticktock.Config `json:"tickTock"` | |||
| RPC rpc.Config `json:"rpc"` | |||
| Logger log.Config `json:"logger"` | |||
| DB db.Config `json:"db"` | |||
| TickTock ticktock.Config `json:"tickTock"` | |||
| RPC rpc.Config `json:"rpc"` | |||
| HubRPC hubrpc.PoolConfigJSON `json:"hubRPC"` | |||
| } | |||
| var cfg Config | |||
| @@ -0,0 +1,49 @@ | |||
| package db | |||
| import ( | |||
| "time" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| "gorm.io/gorm/clause" | |||
| ) | |||
| type LoadedAccessTokenDB struct { | |||
| *DB | |||
| } | |||
| func (db *DB) LoadedAccessToken() *LoadedAccessTokenDB { | |||
| return &LoadedAccessTokenDB{DB: db} | |||
| } | |||
| func (db *LoadedAccessTokenDB) GetByUserIDAndTokenID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) ([]cortypes.LoadedAccessToken, error) { | |||
| var ret []cortypes.LoadedAccessToken | |||
| err := ctx.Table("LoadedAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Find(&ret).Error | |||
| return ret, err | |||
| } | |||
| func (*LoadedAccessTokenDB) CreateOrUpdate(ctx SQLContext, token cortypes.LoadedAccessToken) error { | |||
| return ctx.Clauses(clause.OnConflict{ | |||
| Columns: []clause.Column{{Name: "UserID"}, {Name: "TokenID"}, {Name: "HubID"}}, | |||
| DoUpdates: clause.AssignmentColumns([]string{"LoadedAt"}), | |||
| }).Create(token).Error | |||
| } | |||
| func (*LoadedAccessTokenDB) GetExpired(ctx SQLContext, expireAt time.Time) ([]cortypes.LoadedAccessToken, error) { | |||
| var ret []cortypes.LoadedAccessToken | |||
| err := ctx.Table("LoadedAccessToken"). | |||
| Select("LoadedAccessToken.*"). | |||
| Joins("join UserAccessToken on UserAccessToken.UserID = LoadedAccessToken.UserID and UserAccessToken.TokenID = LoadedAccessToken.TokenID"). | |||
| Where("UserAccessToken.ExpiresAt < ?", expireAt). | |||
| Find(&ret).Error | |||
| return ret, err | |||
| } | |||
| func (*LoadedAccessTokenDB) DeleteExpired(ctx SQLContext, expireAt time.Time) error { | |||
| return ctx.Table("LoadedAccessToken"). | |||
| Where("UserID in (select UserID from UserAccessToken where ExpiresAt < ?)", expireAt). | |||
| Delete(&cortypes.LoadedAccessToken{}).Error | |||
| } | |||
| func (db *LoadedAccessTokenDB) DeleteAllByUserIDAndTokenID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) error { | |||
| return ctx.Table("LoadedAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&cortypes.LoadedAccessToken{}).Error | |||
| } | |||
| @@ -19,14 +19,14 @@ func (db *UserDB) GetByID(ctx SQLContext, userID cortypes.UserID) (cortypes.User | |||
| return ret, err | |||
| } | |||
| func (db *UserDB) GetByName(ctx SQLContext, name string) (cortypes.User, error) { | |||
| func (db *UserDB) GetByAccount(ctx SQLContext, account string) (cortypes.User, error) { | |||
| var ret cortypes.User | |||
| err := ctx.Table("User").Where("Name = ?", name).First(&ret).Error | |||
| err := ctx.Table("User").Where("Account = ?", account).First(&ret).Error | |||
| return ret, err | |||
| } | |||
| func (db *UserDB) Create(ctx SQLContext, name string) (cortypes.User, error) { | |||
| _, err := db.GetByName(ctx, name) | |||
| func (db *UserDB) Create(ctx SQLContext, account string, password string, nickName string) (cortypes.User, error) { | |||
| _, err := db.GetByAccount(ctx, account) | |||
| if err == nil { | |||
| return cortypes.User{}, gorm.ErrDuplicatedKey | |||
| } | |||
| @@ -34,7 +34,7 @@ func (db *UserDB) Create(ctx SQLContext, name string) (cortypes.User, error) { | |||
| return cortypes.User{}, err | |||
| } | |||
| user := cortypes.User{Name: name} | |||
| user := cortypes.User{NickName: nickName, Account: account, Password: password} | |||
| err = ctx.Table("User").Create(&user).Error | |||
| return user, err | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| package db | |||
| import ( | |||
| "time" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type UserAccessTokenDB struct { | |||
| *DB | |||
| } | |||
| func (db *DB) UserAccessToken() *UserAccessTokenDB { | |||
| return &UserAccessTokenDB{DB: db} | |||
| } | |||
| func (db *UserAccessTokenDB) GetByID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) (cortypes.UserAccessToken, error) { | |||
| var ret cortypes.UserAccessToken | |||
| err := ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).First(&ret).Error | |||
| return ret, err | |||
| } | |||
| func (*UserAccessTokenDB) Create(ctx SQLContext, token *cortypes.UserAccessToken) error { | |||
| return ctx.Table("UserAccessToken").Create(token).Error | |||
| } | |||
| func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) error { | |||
| return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&cortypes.UserAccessToken{}).Error | |||
| } | |||
| func (*UserAccessTokenDB) DeleteExpired(ctx SQLContext, expireTime time.Time) error { | |||
| return ctx.Table("UserAccessToken").Where("ExpiresAt < ?", expireTime).Delete(&cortypes.UserAccessToken{}).Error | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| package repl | |||
| import ( | |||
| "encoding/hex" | |||
| "fmt" | |||
| "os" | |||
| "github.com/spf13/cobra" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| "golang.org/x/crypto/bcrypt" | |||
| "golang.org/x/term" | |||
| ) | |||
| func init() { | |||
| userCmd := &cobra.Command{ | |||
| Use: "user", | |||
| Short: "user command", | |||
| } | |||
| RootCmd.AddCommand(userCmd) | |||
| createCmd := &cobra.Command{ | |||
| Use: "create [account] [nickName]", | |||
| Short: "create a new user account", | |||
| Args: cobra.ExactArgs(2), | |||
| Run: func(cmd *cobra.Command, args []string) { | |||
| userCreate(GetCmdCtx(cmd), args[0], args[1]) | |||
| }, | |||
| } | |||
| userCmd.AddCommand(createCmd) | |||
| } | |||
| func userCreate(ctx *CommandContext, account string, nickName string) { | |||
| _, err := ctx.repl.db.User().GetByAccount(ctx.repl.db.DefCtx(), account) | |||
| if err == nil { | |||
| fmt.Printf("user %s already exists\n", account) | |||
| return | |||
| } | |||
| fmt.Printf("input account password: ") | |||
| pass, err := term.ReadPassword(int(os.Stdin.Fd())) | |||
| if err != nil { | |||
| fmt.Println("error reading password:", err) | |||
| return | |||
| } | |||
| passHash, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost) | |||
| if err != nil { | |||
| fmt.Println("error hashing password:", err) | |||
| return | |||
| } | |||
| user, err := db.DoTx02(ctx.repl.db, func(tx db.SQLContext) (cortypes.User, error) { | |||
| return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) | |||
| }) | |||
| if err != nil { | |||
| fmt.Println("error creating user:", err) | |||
| return | |||
| } | |||
| fmt.Printf("user %s created\n", user.Account) | |||
| } | |||
| @@ -1,15 +1,18 @@ | |||
| package rpc | |||
| import ( | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| ) | |||
| type Service struct { | |||
| db *db.DB | |||
| db *db.DB | |||
| accessToken *accesstoken.Cache | |||
| } | |||
| func NewService(db *db.DB) *Service { | |||
| func NewService(db *db.DB, accessToken *accesstoken.Cache) *Service { | |||
| return &Service{ | |||
| db: db, | |||
| db: db, | |||
| accessToken: accessToken, | |||
| } | |||
| } | |||
| @@ -0,0 +1,227 @@ | |||
| package rpc | |||
| import ( | |||
| "context" | |||
| "crypto/ed25519" | |||
| "encoding/hex" | |||
| "fmt" | |||
| "time" | |||
| "github.com/google/uuid" | |||
| "gitlink.org.cn/cloudream/common/consts/errorcode" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| "golang.org/x/crypto/bcrypt" | |||
| "gorm.io/gorm" | |||
| ) | |||
| func (svc *Service) UserLogin(ctx context.Context, msg *corrpc.UserLogin) (*corrpc.UserLoginResp, *rpc.CodeError) { | |||
| log := logger.WithField("Account", msg.Account) | |||
| user, err := svc.db.User().GetByAccount(svc.db.DefCtx(), msg.Account) | |||
| if err != nil { | |||
| if err == gorm.ErrRecordNotFound { | |||
| log.Warnf("account not found") | |||
| return nil, rpc.Failed(errorcode.DataNotFound, "account not found") | |||
| } | |||
| log.Warnf("getting account: %v", err) | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "getting account: %v", err) | |||
| } | |||
| dbPass, err := hex.DecodeString(user.Password) | |||
| if err != nil { | |||
| log.Warnf("decoding password: %v", err) | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "decoding password: %v", err) | |||
| } | |||
| if bcrypt.CompareHashAndPassword(dbPass, []byte(msg.Password)) != nil { | |||
| log.Warnf("password not match") | |||
| return nil, rpc.Failed(errorcode.Unauthorized, "password not match") | |||
| } | |||
| pubKey, priKey, err := ed25519.GenerateKey(nil) | |||
| if err != nil { | |||
| log.Warnf("generating key: %v", err) | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "generating key: %v", err) | |||
| } | |||
| pubKeyStr := hex.EncodeToString(pubKey) | |||
| nowTime := time.Now() | |||
| token := cortypes.UserAccessToken{ | |||
| UserID: user.UserID, | |||
| TokenID: cortypes.AccessTokenID(uuid.NewString()), | |||
| PublicKey: pubKeyStr, | |||
| ExpiresAt: nowTime.Add(time.Hour), | |||
| CreatedAt: nowTime, | |||
| } | |||
| err = svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token) | |||
| if err != nil { | |||
| log.Warnf("creating token: %v", err) | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "creating token: %v", err) | |||
| } | |||
| log.Infof("login success, token expires at %v", token.ExpiresAt) | |||
| return &corrpc.UserLoginResp{ | |||
| Token: token, | |||
| PrivateKey: hex.EncodeToString(priKey), | |||
| }, nil | |||
| } | |||
| func (svc *Service) UserRefreshToken(ctx context.Context, msg *corrpc.UserRefreshToken) (*corrpc.UserRefreshTokenResp, *rpc.CodeError) { | |||
| authInfo, ok := rpc.GetAuthInfo(ctx) | |||
| if !ok { | |||
| return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized") | |||
| } | |||
| log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) | |||
| pubKey, priKey, err := ed25519.GenerateKey(nil) | |||
| if err != nil { | |||
| log.Warnf("generating key: %v", err) | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "generating key: %v", err) | |||
| } | |||
| pubKeyStr := hex.EncodeToString(pubKey) | |||
| nowTime := time.Now() | |||
| token := cortypes.UserAccessToken{ | |||
| UserID: authInfo.UserID, | |||
| TokenID: cortypes.AccessTokenID(uuid.NewString()), | |||
| PublicKey: pubKeyStr, | |||
| ExpiresAt: nowTime.Add(time.Hour), | |||
| CreatedAt: nowTime, | |||
| } | |||
| err = svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token) | |||
| if err != nil { | |||
| log.Warnf("creating token: %v", err) | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "creating token: %v", err) | |||
| } | |||
| log.Infof("refresh token success, new token expires at %v", token.ExpiresAt) | |||
| return &corrpc.UserRefreshTokenResp{ | |||
| Token: token, | |||
| PrivateKey: hex.EncodeToString(priKey), | |||
| }, nil | |||
| } | |||
| func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*corrpc.UserLogoutResp, *rpc.CodeError) { | |||
| authInfo, ok := rpc.GetAuthInfo(ctx) | |||
| if !ok { | |||
| return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized") | |||
| } | |||
| log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) | |||
| loaded, err := db.DoTx02(svc.db, func(tx db.SQLContext) ([]cortypes.LoadedAccessToken, error) { | |||
| token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| err = svc.db.UserAccessToken().DeleteByID(tx, token.UserID, token.TokenID) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| loaded, err := svc.db.LoadedAccessToken().GetByUserIDAndTokenID(tx, token.UserID, token.TokenID) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| err = svc.db.LoadedAccessToken().DeleteAllByUserIDAndTokenID(tx, token.UserID, token.TokenID) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return loaded, nil | |||
| }) | |||
| if err != nil { | |||
| log.Warnf("delete access token: %v", err) | |||
| if err == gorm.ErrRecordNotFound { | |||
| return nil, rpc.Failed(errorcode.DataNotFound, "token not found") | |||
| } | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "delete access token: %v", err) | |||
| } | |||
| svc.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ | |||
| UserID: authInfo.UserID, | |||
| TokenID: authInfo.AccessTokenID, | |||
| }) | |||
| var loadedHubIDs []cortypes.HubID | |||
| for _, l := range loaded { | |||
| loadedHubIDs = append(loadedHubIDs, l.HubID) | |||
| } | |||
| svc.notifyLoadedHubs(authInfo.UserID, authInfo.AccessTokenID, loadedHubIDs) | |||
| return &corrpc.UserLogoutResp{}, nil | |||
| } | |||
| func (svc *Service) notifyLoadedHubs(userID cortypes.UserID, tokenID cortypes.AccessTokenID, loadedHubIDs []cortypes.HubID) { | |||
| log := logger.WithField("UserID", userID).WithField("TokenID", tokenID) | |||
| loadedHubs, err := svc.db.Hub().BatchGetByID(svc.db.DefCtx(), loadedHubIDs) | |||
| if err != nil { | |||
| log.Warnf("getting hubs: %v", err) | |||
| return | |||
| } | |||
| for _, l := range loadedHubs { | |||
| addr, ok := l.Address.(*cortypes.GRPCAddressInfo) | |||
| if !ok { | |||
| continue | |||
| } | |||
| cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort) | |||
| // 不关心返回值 | |||
| cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{ | |||
| UserID: userID, | |||
| TokenID: tokenID, | |||
| }) | |||
| cli.Release() | |||
| } | |||
| } | |||
| func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { | |||
| token, err := db.DoTx02(svc.db, func(tx db.SQLContext) (cortypes.UserAccessToken, error) { | |||
| token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) | |||
| if err != nil { | |||
| return cortypes.UserAccessToken{}, err | |||
| } | |||
| err = svc.db.LoadedAccessToken().CreateOrUpdate(tx, cortypes.LoadedAccessToken{ | |||
| UserID: msg.UserID, | |||
| TokenID: msg.TokenID, | |||
| HubID: msg.HubID, | |||
| LoadedAt: time.Now(), | |||
| }) | |||
| if err != nil { | |||
| return cortypes.UserAccessToken{}, fmt.Errorf("creating access token loaded record: %v", err) | |||
| } | |||
| return token, nil | |||
| }) | |||
| if err != nil { | |||
| if err == gorm.ErrRecordNotFound { | |||
| return nil, rpc.Failed(errorcode.DataNotFound, "token not found") | |||
| } | |||
| return nil, rpc.Failed(errorcode.OperationFailed, "loading access token: %v", err) | |||
| } | |||
| return &corrpc.HubLoadAccessTokenResp{ | |||
| Token: token, | |||
| }, nil | |||
| } | |||
| @@ -0,0 +1,112 @@ | |||
| package ticktock | |||
| import ( | |||
| "context" | |||
| "fmt" | |||
| "time" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/common/utils/reflect2" | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type ClearExpiredAccessToken struct { | |||
| } | |||
| func (j *ClearExpiredAccessToken) Name() string { | |||
| return reflect2.TypeNameOf[ClearExpiredAccessToken]() | |||
| } | |||
| func (j *ClearExpiredAccessToken) Execute(t *TickTock) { | |||
| log := logger.WithType[ClearExpiredAccessToken]("TickTock") | |||
| log.Infof("job start") | |||
| startTime := time.Now() | |||
| defer func() { | |||
| log.Infof("job end, time: %v", time.Since(startTime)) | |||
| }() | |||
| expired, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]cortypes.LoadedAccessToken, error) { | |||
| nowTime := time.Now() | |||
| expired, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("get expired access token load record: %w", err) | |||
| } | |||
| err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("delete expired access token load record: %w", err) | |||
| } | |||
| err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("delete expired user access token: %w", err) | |||
| } | |||
| return expired, nil | |||
| }) | |||
| if err != nil { | |||
| log.Warn(err.Error()) | |||
| return | |||
| } | |||
| uniToken := make(map[accesstoken.CacheKey]bool) | |||
| for _, t := range expired { | |||
| uniToken[accesstoken.CacheKey{ | |||
| UserID: t.UserID, | |||
| TokenID: t.TokenID, | |||
| }] = true | |||
| } | |||
| log.Infof("%v expired access token cleared", len(uniToken)) | |||
| // 通知本服务的AccessToken缓存失效 | |||
| for k := range uniToken { | |||
| t.accessToken.NotifyTokenInvalid(k) | |||
| } | |||
| // 通知所有加载了失效Token的Hub | |||
| var loadedHubIDs []cortypes.HubID | |||
| for _, e := range expired { | |||
| loadedHubIDs = append(loadedHubIDs, e.HubID) | |||
| } | |||
| loadedHubs, err := t.db.Hub().BatchGetByID(t.db.DefCtx(), loadedHubIDs) | |||
| if err != nil { | |||
| log.Warnf("getting hubs: %v", err) | |||
| return | |||
| } | |||
| hubMap := make(map[cortypes.HubID]cortypes.Hub) | |||
| for _, h := range loadedHubs { | |||
| hubMap[h.HubID] = h | |||
| } | |||
| for _, e := range expired { | |||
| h, ok := hubMap[e.HubID] | |||
| if !ok { | |||
| continue | |||
| } | |||
| addr, ok := h.Address.(*cortypes.GRPCAddressInfo) | |||
| if !ok { | |||
| continue | |||
| } | |||
| cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort) | |||
| // 不关心返回值 | |||
| _, err := cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{ | |||
| UserID: e.UserID, | |||
| TokenID: e.TokenID, | |||
| }) | |||
| if err != nil { | |||
| log.Warnf("notify hub %v: %v", h.HubID, err) | |||
| } | |||
| cli.Release() | |||
| } | |||
| } | |||
| @@ -6,6 +6,7 @@ import ( | |||
| "github.com/go-co-op/gocron/v2" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" | |||
| ) | |||
| @@ -20,19 +21,21 @@ type cronJob struct { | |||
| } | |||
| type TickTock struct { | |||
| cfg Config | |||
| sch gocron.Scheduler | |||
| jobs map[string]cronJob | |||
| db *db.DB | |||
| cfg Config | |||
| sch gocron.Scheduler | |||
| jobs map[string]cronJob | |||
| db *db.DB | |||
| accessToken *accesstoken.Cache | |||
| } | |||
| func New(cfg Config, db *db.DB) *TickTock { | |||
| func New(cfg Config, db *db.DB, accessToken *accesstoken.Cache) *TickTock { | |||
| sch, _ := gocron.NewScheduler() | |||
| t := &TickTock{ | |||
| cfg: cfg, | |||
| sch: sch, | |||
| jobs: map[string]cronJob{}, | |||
| db: db, | |||
| cfg: cfg, | |||
| sch: sch, | |||
| jobs: map[string]cronJob{}, | |||
| db: db, | |||
| accessToken: accessToken, | |||
| } | |||
| t.initJobs() | |||
| return t | |||
| @@ -70,4 +73,6 @@ func (t *TickTock) addJob(job Job, duration gocron.JobDefinition) { | |||
| func (t *TickTock) initJobs() { | |||
| t.addJob(&CheckHubState{}, gocron.DurationJob(time.Minute*5)) | |||
| t.addJob(&ClearExpiredAccessToken{}, gocron.DurationJob(time.Minute*5)) | |||
| } | |||
| @@ -68,21 +68,49 @@ func (HubConnectivity) TableName() string { | |||
| return "HubConnectivity" | |||
| } | |||
| type HubLocation struct { | |||
| HubID HubID `gorm:"column:HubID; type:bigint" json:"hubID"` | |||
| StorageName string `gorm:"column:StorageName; type:varchar(255); not null" json:"storageName"` | |||
| Location string `gorm:"column:Location; type:varchar(255); not null" json:"location"` | |||
| } | |||
| func (HubLocation) TableName() string { | |||
| return "HubLocation" | |||
| } | |||
| type User struct { | |||
| UserID UserID `gorm:"column:UserID; primaryKey; type:bigint; autoIncrement" json:"userID"` | |||
| Name string `gorm:"column:Name; type:varchar(255); not null" json:"name"` | |||
| UserID UserID `gorm:"column:UserID; primaryKey; type:bigint; autoIncrement" json:"userID"` | |||
| NickName string `gorm:"column:NickName; type:varchar(255); not null" json:"nickName"` | |||
| Account string `gorm:"column:Account; type:varchar(255); not null" json:"account"` | |||
| // bcrypt哈希过的密码,带有盐值 | |||
| Password string `gorm:"column:Password; type:varchar(255); not null" json:"password"` | |||
| } | |||
| func (User) TableName() string { | |||
| return "User" | |||
| } | |||
| type HubLocation struct { | |||
| HubID HubID `gorm:"column:HubID; type:bigint" json:"hubID"` | |||
| StorageName string `gorm:"column:StorageName; type:varchar(255); not null" json:"storageName"` | |||
| Location string `gorm:"column:Location; type:varchar(255); not null" json:"location"` | |||
| type AccessTokenID string | |||
| type UserAccessToken struct { | |||
| UserID UserID `gorm:"column:UserID; primaryKey; type:bigint" json:"userID"` | |||
| TokenID AccessTokenID `gorm:"column:TokenID; primaryKey; type:char(36); not null" json:"tokenID"` | |||
| PublicKey string `gorm:"column:PublicKey; type:char(64); not null" json:"publicKey"` | |||
| ExpiresAt time.Time `gorm:"column:ExpiresAt; type:datetime" json:"expiresAt"` | |||
| CreatedAt time.Time `gorm:"column:CreatedAt; type:datetime" json:"createdAt"` | |||
| } | |||
| func (HubLocation) TableName() string { | |||
| return "HubLocation" | |||
| func (UserAccessToken) TableName() string { | |||
| return "UserAccessToken" | |||
| } | |||
| type LoadedAccessToken struct { | |||
| UserID UserID `gorm:"column:UserID; primaryKey; type:bigint" json:"userID"` | |||
| TokenID AccessTokenID `gorm:"column:TokenID; primaryKey; type:char(36); not null" json:"tokenID"` | |||
| HubID HubID `gorm:"column:HubID; primaryKey; type:bigint" json:"hubID"` | |||
| LoadedAt time.Time `gorm:"column:LoadedAt; type:datetime" json:"loadedAt"` | |||
| } | |||
| func (LoadedAccessToken) TableName() string { | |||
| return "LoadedAccessToken" | |||
| } | |||
| @@ -28,10 +28,11 @@ require ( | |||
| github.com/spf13/cobra v1.8.1 | |||
| github.com/stretchr/testify v1.10.0 | |||
| gitlink.org.cn/cloudream/common v0.0.0 | |||
| golang.org/x/crypto v0.38.0 | |||
| golang.org/x/net v0.35.0 | |||
| golang.org/x/sync v0.13.0 | |||
| golang.org/x/sys v0.32.0 | |||
| golang.org/x/term v0.31.0 | |||
| golang.org/x/sync v0.14.0 | |||
| golang.org/x/sys v0.33.0 | |||
| golang.org/x/term v0.32.0 | |||
| google.golang.org/grpc v1.67.1 | |||
| google.golang.org/protobuf v1.36.6 | |||
| gorm.io/gorm v1.25.12 | |||
| @@ -70,9 +71,8 @@ require ( | |||
| github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | |||
| go.mongodb.org/mongo-driver v1.12.0 // indirect | |||
| golang.org/x/arch v0.8.0 // indirect | |||
| golang.org/x/crypto v0.37.0 // indirect | |||
| golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect | |||
| golang.org/x/text v0.24.0 // indirect | |||
| golang.org/x/text v0.25.0 // indirect | |||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect | |||
| gopkg.in/ini.v1 v1.67.0 // indirect | |||
| gopkg.in/yaml.v3 v3.0.1 // indirect | |||
| @@ -243,8 +243,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y | |||
| golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= | |||
| golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= | |||
| golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= | |||
| golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= | |||
| golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= | |||
| golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= | |||
| golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= | |||
| golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | |||
| golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= | |||
| golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= | |||
| @@ -275,8 +275,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ | |||
| golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||
| golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||
| golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||
| golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= | |||
| golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= | |||
| golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= | |||
| golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= | |||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||
| golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||
| @@ -300,16 +300,16 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | |||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | |||
| golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | |||
| golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | |||
| golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= | |||
| golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= | |||
| golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= | |||
| golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= | |||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | |||
| golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= | |||
| golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= | |||
| golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= | |||
| golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= | |||
| golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= | |||
| golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= | |||
| golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= | |||
| golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= | |||
| golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= | |||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | |||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | |||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | |||
| @@ -319,8 +319,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= | |||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | |||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | |||
| golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | |||
| golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= | |||
| golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= | |||
| golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= | |||
| golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= | |||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | |||
| golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | |||
| golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= | |||
| @@ -0,0 +1,49 @@ | |||
| package accesstoken | |||
| import ( | |||
| "context" | |||
| "gitlink.org.cn/cloudream/common/consts/errorcode" | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| ) | |||
| type ExitEvent = accesstoken.ExitEvent | |||
| type CacheKey = accesstoken.CacheKey | |||
| type Cache struct { | |||
| localHubID cortypes.HubID | |||
| *accesstoken.Cache | |||
| } | |||
| func New(localHubID cortypes.HubID) *Cache { | |||
| c := &Cache{ | |||
| localHubID: localHubID, | |||
| } | |||
| c.Cache = accesstoken.New(c.load) | |||
| return c | |||
| } | |||
| func (c *Cache) load(key accesstoken.CacheKey) (cortypes.UserAccessToken, error) { | |||
| corCli := stgglb.CoordinatorRPCPool.Get() | |||
| defer corCli.Release() | |||
| tokenResp, cerr := corCli.HubLoadAccessToken(context.Background(), &corrpc.HubLoadAccessToken{ | |||
| UserID: key.UserID, | |||
| TokenID: key.TokenID, | |||
| HubID: c.localHubID, | |||
| }) | |||
| if cerr != nil { | |||
| if cerr.Code == errorcode.DataNotFound { | |||
| return cortypes.UserAccessToken{}, accesstoken.ErrTokenNotFound | |||
| } | |||
| return cortypes.UserAccessToken{}, cerr.ToError() | |||
| } | |||
| return tokenResp.Token, nil | |||
| } | |||
| @@ -9,6 +9,7 @@ import ( | |||
| "github.com/spf13/cobra" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/storage/pool" | |||
| "gitlink.org.cn/cloudream/jcs-pub/hub/internal/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/hub/internal/http" | |||
| myrpc "gitlink.org.cn/cloudream/jcs-pub/hub/internal/rpc" | |||
| @@ -22,7 +23,7 @@ import ( | |||
| "gitlink.org.cn/cloudream/jcs-pub/hub/internal/config" | |||
| "gitlink.org.cn/cloudream/jcs-pub/hub/internal/ticktock" | |||
| coormq "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| ) | |||
| func init() { | |||
| @@ -47,22 +48,36 @@ type serveOptions struct { | |||
| } | |||
| func serve(configPath string, opts serveOptions) { | |||
| // 加载服务配置 | |||
| err := config.Init(configPath) | |||
| if err != nil { | |||
| fmt.Printf("init config failed, err: %s", err.Error()) | |||
| os.Exit(1) | |||
| } | |||
| // 初始化日志 | |||
| err = logger.Init(&config.Cfg().Logger) | |||
| if err != nil { | |||
| fmt.Printf("init logger failed, err: %s", err.Error()) | |||
| os.Exit(1) | |||
| } | |||
| // 初始化全局变量 | |||
| stgglb.InitLocal(config.Cfg().Local) | |||
| stgglb.InitPools(&hubrpc.PoolConfig{}, &config.Cfg().CoordinatorRPC) | |||
| // stgglb.Stats.SetupHubStorageTransfer(*config.Cfg().Local.HubID) | |||
| // stgglb.Stats.SetupHubTransfer(*config.Cfg().Local.HubID) | |||
| // 初始化各服务客户端的连接池 | |||
| corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(nil) | |||
| if err != nil { | |||
| logger.Errorf("building coordinator rpc config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| hubRPCCfg, err := config.Cfg().HubRPC.Build(nil) | |||
| if err != nil { | |||
| logger.Errorf("building hub rpc config: %v", err) | |||
| os.Exit(1) | |||
| } | |||
| stgglb.InitPools(hubRPCCfg, corRPCCfg) | |||
| // 获取Hub配置 | |||
| hubCfg := downloadHubConfig() | |||
| @@ -109,13 +124,19 @@ func serve(configPath string, opts serveOptions) { | |||
| tktk.Start() | |||
| defer tktk.Stop() | |||
| // 客户端访问令牌管理器 | |||
| accToken := accesstoken.New(config.Cfg().ID) | |||
| accTokenChan := accToken.Start() | |||
| defer accToken.Stop() | |||
| // RPC服务 | |||
| rpcSvr := hubrpc.NewServer(config.Cfg().RPC, myrpc.NewService(&worker, stgPool)) | |||
| rpcSvr := hubrpc.NewServer(config.Cfg().RPC, myrpc.NewService(&worker, stgPool, accToken), accToken) | |||
| rpcSvrChan := rpcSvr.Start() | |||
| defer rpcSvr.Stop() | |||
| /// 开始监听各个模块的事件 | |||
| evtPubEvt := evtPubChan.Receive() | |||
| accTokenEvt := accTokenChan.Receive() | |||
| rpcEvt := rpcSvrChan.Receive() | |||
| httpEvt := httpChan.Receive() | |||
| @@ -145,6 +166,23 @@ loop: | |||
| } | |||
| evtPubEvt = evtPubChan.Receive() | |||
| case e := <-accTokenEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive access token event: %v", err) | |||
| break loop | |||
| } | |||
| switch e := e.Value.(type) { | |||
| case accesstoken.ExitEvent: | |||
| if e.Err != nil { | |||
| logger.Errorf("access token manager exited with error: %v", e.Err) | |||
| } else { | |||
| logger.Info("access token manager exited") | |||
| } | |||
| break loop | |||
| } | |||
| accTokenEvt = accTokenChan.Receive() | |||
| case e := <-rpcEvt.Chan(): | |||
| if e.Err != nil { | |||
| logger.Errorf("receive rpc event: %v", e.Err) | |||
| @@ -179,14 +217,14 @@ loop: | |||
| } | |||
| func downloadHubConfig() coormq.GetHubConfigResp { | |||
| func downloadHubConfig() corrpc.GetHubConfigResp { | |||
| coorCli := stgglb.CoordinatorRPCPool.Get() | |||
| defer coorCli.Release() | |||
| ctx, cancel := context.WithTimeout(context.Background(), time.Minute) | |||
| defer cancel() | |||
| cfgResp, cerr := coorCli.GetHubConfig(ctx, coormq.ReqGetHubConfig(cortypes.HubID(config.Cfg().ID))) | |||
| cfgResp, cerr := coorCli.GetHubConfig(ctx, corrpc.ReqGetHubConfig(cortypes.HubID(config.Cfg().ID))) | |||
| if cerr != nil { | |||
| logger.Errorf("getting hub config: %v", cerr) | |||
| os.Exit(1) | |||
| @@ -6,6 +6,7 @@ import ( | |||
| stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/sysevent" | |||
| cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" | |||
| "gitlink.org.cn/cloudream/jcs-pub/hub/internal/http" | |||
| @@ -17,7 +18,8 @@ type Config struct { | |||
| Local stgglb.LocalMachineInfo `json:"local"` | |||
| RPC rpc.Config `json:"rpc"` | |||
| HTTP *http.Config `json:"http"` | |||
| CoordinatorRPC corrpc.PoolConfig `json:"coordinatorRPC"` | |||
| CoordinatorRPC corrpc.PoolConfigJSON `json:"coordinatorRPC"` | |||
| HubRPC hubrpc.PoolConfigJSON `json:"hubRPC"` | |||
| Logger log.Config `json:"logger"` | |||
| SysEvent sysevent.Config `json:"sysEvent"` | |||
| TickTock ticktock.Config `json:"tickTock"` | |||
| @@ -4,17 +4,20 @@ import ( | |||
| "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/storage/pool" | |||
| "gitlink.org.cn/cloudream/jcs-pub/hub/internal/accesstoken" | |||
| ) | |||
| type Service struct { | |||
| swWorker *exec.Worker | |||
| stgPool *pool.Pool | |||
| swWorker *exec.Worker | |||
| stgPool *pool.Pool | |||
| accessToken *accesstoken.Cache | |||
| } | |||
| func NewService(swWorker *exec.Worker, stgPool *pool.Pool) *Service { | |||
| func NewService(swWorker *exec.Worker, stgPool *pool.Pool, accessToken *accesstoken.Cache) *Service { | |||
| return &Service{ | |||
| swWorker: swWorker, | |||
| stgPool: stgPool, | |||
| swWorker: swWorker, | |||
| stgPool: stgPool, | |||
| accessToken: accessToken, | |||
| } | |||
| } | |||
| @@ -0,0 +1,19 @@ | |||
| package rpc | |||
| import ( | |||
| "context" | |||
| "gitlink.org.cn/cloudream/common/pkgs/logger" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" | |||
| "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" | |||
| hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" | |||
| ) | |||
| func (s *Service) NotifyUserAccessTokenInvalid(ctx context.Context, msg *hubrpc.NotifyUserAccessTokenInvalid) (*hubrpc.NotifyUserAccessTokenInvalidResp, *rpc.CodeError) { | |||
| s.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ | |||
| UserID: msg.UserID, | |||
| TokenID: msg.TokenID, | |||
| }) | |||
| logger.WithField("UserID", msg.UserID).WithField("TokenID", msg.TokenID).Infof("user access token invalid") | |||
| return &hubrpc.NotifyUserAccessTokenInvalidResp{}, nil | |||
| } | |||