package rpc

import (
	"context"
	"crypto/ed25519"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"gitlink.org.cn/cloudream/common/consts/errorcode"
	"gitlink.org.cn/cloudream/common/pkgs/logger"
	stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc"
	corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator"
	hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub"
	jcstypes "gitlink.org.cn/cloudream/jcs-pub/common/types"
	"gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// UserLogin authenticates an account by password and issues a new access
// token for it.
//
// The user is looked up by msg.Account, the hex-encoded bcrypt hash stored in
// the database is compared against msg.Password, and on success a new token
// backed by a fresh ed25519 key pair is persisted. The response carries the
// token together with the hex-encoded private key.
//
// Error codes: DataNotFound when the account does not exist, Unauthorized when
// the password does not match, OperationFailed for every other failure.
//
// NOTE(review): the two distinct codes for "unknown account" and "wrong
// password" let a caller probe which accounts exist; kept as-is because
// callers may depend on them.
func (svc *Service) UserLogin(ctx context.Context, msg *corrpc.UserLogin) (*corrpc.UserLoginResp, *rpc.CodeError) {
	log := logger.WithField("Account", msg.Account)

	user, err := svc.db.User().GetByAccount(svc.db.DefCtx(), msg.Account)
	if err != nil {
		// errors.Is keeps matching even if the DB layer wraps the sentinel.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Warnf("account not found")
			return nil, rpc.Failed(errorcode.DataNotFound, "account not found")
		}

		log.Warnf("getting account: %v", err)
		return nil, rpc.Failed(errorcode.OperationFailed, "getting account: %v", err)
	}

	// The password column holds the bcrypt hash as a hex string.
	dbPass, err := hex.DecodeString(user.Password)
	if err != nil {
		log.Warnf("decoding password: %v", err)
		return nil, rpc.Failed(errorcode.OperationFailed, "decoding password: %v", err)
	}

	if bcrypt.CompareHashAndPassword(dbPass, []byte(msg.Password)) != nil {
		log.Warnf("password not match")
		return nil, rpc.Failed(errorcode.Unauthorized, "password not match")
	}

	token, priKeyHex, err := svc.issueUserAccessToken(user.UserID)
	if err != nil {
		log.Warnf("%v", err)
		return nil, rpc.Failed(errorcode.OperationFailed, "%v", err)
	}

	log.Infof("login success, token expires at %v", token.ExpiresAt)

	return &corrpc.UserLoginResp{
		Token:      token,
		PrivateKey: priKeyHex,
	}, nil
}

// UserRefreshToken issues a new access token for the already authenticated
// caller.
//
// The caller identity is taken from the auth info attached to ctx; when it is
// missing the call fails with Unauthorized. A new token with a fresh ed25519
// key pair is persisted and returned together with the hex-encoded private
// key. The token that authenticated this request is not modified or deleted
// by this method.
func (svc *Service) UserRefreshToken(ctx context.Context, msg *corrpc.UserRefreshToken) (*corrpc.UserRefreshTokenResp, *rpc.CodeError) {
	authInfo, ok := rpc.GetAuthInfo(ctx)
	if !ok {
		return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized")
	}

	log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID)

	token, priKeyHex, err := svc.issueUserAccessToken(authInfo.UserID)
	if err != nil {
		log.Warnf("%v", err)
		return nil, rpc.Failed(errorcode.OperationFailed, "%v", err)
	}

	log.Infof("refresh token success, new token expires at %v", token.ExpiresAt)

	return &corrpc.UserRefreshTokenResp{
		Token:      token,
		PrivateKey: priKeyHex,
	}, nil
}

// issueUserAccessToken generates a fresh ed25519 key pair and persists a new
// access token for userID that expires one hour after creation.
//
// It returns the stored token (holding the hex-encoded public key) and the
// hex-encoded private key. The returned error is prefixed with the step that
// failed ("generating key" or "creating token").
func (svc *Service) issueUserAccessToken(userID jcstypes.UserID) (jcstypes.UserAccessToken, string, error) {
	const tokenTTL = time.Hour

	// A nil reader makes ed25519 draw from crypto/rand.Reader.
	pubKey, priKey, err := ed25519.GenerateKey(nil)
	if err != nil {
		return jcstypes.UserAccessToken{}, "", fmt.Errorf("generating key: %w", err)
	}

	now := time.Now()
	token := jcstypes.UserAccessToken{
		UserID:    userID,
		TokenID:   jcstypes.AccessTokenID(uuid.NewString()),
		PublicKey: hex.EncodeToString(pubKey),
		ExpiresAt: now.Add(tokenTTL),
		CreatedAt: now,
	}

	if err := svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token); err != nil {
		return jcstypes.UserAccessToken{}, "", fmt.Errorf("creating token: %w", err)
	}

	return token, hex.EncodeToString(priKey), nil
}

// UserLogout invalidates the access token that authenticated this request.
//
// Inside one transaction the token and all of its "loaded by hub" records are
// deleted; the loaded records are read first so the affected hubs are known.
// Afterwards the local access-token cache is told the token is invalid and
// every hub that had loaded the token is notified.
//
// Error codes: Unauthorized when ctx carries no auth info, DataNotFound when
// the token no longer exists, OperationFailed for any other failure.
func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*corrpc.UserLogoutResp, *rpc.CodeError) {
	authInfo, ok := rpc.GetAuthInfo(ctx)
	if !ok {
		return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized")
	}

	log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID)

	loaded, err := db.DoTx01(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
		token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID)
		if err != nil {
			return nil, err
		}

		err = svc.db.UserAccessToken().DeleteByID(tx, token.UserID, token.TokenID)
		if err != nil {
			return nil, err
		}

		// Read the loaded records before deleting them: they identify the
		// hubs that must be notified after the transaction.
		loaded, err := svc.db.LoadedAccessToken().GetByUserIDAndTokenID(tx, token.UserID, token.TokenID)
		if err != nil {
			return nil, err
		}

		err = svc.db.LoadedAccessToken().DeleteAllByUserIDAndTokenID(tx, token.UserID, token.TokenID)
		if err != nil {
			return nil, err
		}

		return loaded, nil
	})
	if err != nil {
		log.Warnf("delete access token: %v", err)
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, rpc.Failed(errorcode.DataNotFound, "token not found")
		}

		return nil, rpc.Failed(errorcode.OperationFailed, "delete access token: %v", err)
	}

	svc.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{
		UserID:  authInfo.UserID,
		TokenID: authInfo.AccessTokenID,
	})

	loadedHubIDs := make([]jcstypes.HubID, 0, len(loaded))
	for _, l := range loaded {
		loadedHubIDs = append(loadedHubIDs, l.HubID)
	}

	svc.notifyLoadedHubs(authInfo.UserID, authInfo.AccessTokenID, loadedHubIDs)

	return &corrpc.UserLogoutResp{}, nil
}

// notifyLoadedHubs tells each hub in loadedHubIDs that the given user access
// token is no longer valid.
//
// Hubs are resolved from the database; hubs whose address is not a
// *jcstypes.GRPCAddressInfo are skipped. Notification is best-effort: a
// failure to load the hubs is only logged, and the result of each individual
// notification call is ignored.
func (svc *Service) notifyLoadedHubs(userID jcstypes.UserID, tokenID jcstypes.AccessTokenID, loadedHubIDs []jcstypes.HubID) {
	// Nothing to notify, so skip the database round trip.
	if len(loadedHubIDs) == 0 {
		return
	}

	log := logger.WithField("UserID", userID).WithField("TokenID", tokenID)

	loadedHubs, err := svc.db.Hub().BatchGetByID(svc.db.DefCtx(), loadedHubIDs)
	if err != nil {
		log.Warnf("getting hubs: %v", err)
		return
	}

	for _, l := range loadedHubs {
		addr, ok := l.Address.(*jcstypes.GRPCAddressInfo)
		if !ok {
			continue
		}

		cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort)
		// The return value is deliberately not checked.
		cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{
			UserID:  userID,
			TokenID: tokenID,
		})
		cli.Release()
	}
}

// HubLoadAccessToken returns the access token identified by msg.UserID and
// msg.TokenID and records that hub msg.HubID has loaded it.
//
// The lookup and the creation/update of the loaded record (stamped with the
// current time) run in one transaction.
//
// Error codes: DataNotFound when the token does not exist, OperationFailed for
// any other failure.
func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) {
	token, err := db.DoTx01(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) {
		token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID)
		if err != nil {
			return jcstypes.UserAccessToken{}, err
		}

		err = svc.db.LoadedAccessToken().CreateOrUpdate(tx, jcstypes.LoadedAccessToken{
			UserID:   msg.UserID,
			TokenID:  msg.TokenID,
			HubID:    msg.HubID,
			LoadedAt: time.Now(),
		})
		if err != nil {
			// %v rather than %w on purpose: a failure while writing the
			// loaded record must never match the record-not-found check
			// below and be reported as "token not found".
			return jcstypes.UserAccessToken{}, fmt.Errorf("creating access token loaded record: %v", err)
		}

		return token, nil
	})
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, rpc.Failed(errorcode.DataNotFound, "token not found")
		}

		return nil, rpc.Failed(errorcode.OperationFailed, "loading access token: %v", err)
	}

	return &corrpc.HubLoadAccessTokenResp{
		Token: token,
	}, nil
}