|
- package rpc
-
import (
	"context"
	"crypto/ed25519"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"gitlink.org.cn/cloudream/common/consts/errorcode"
	"gitlink.org.cn/cloudream/common/pkgs/logger"
	stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc"
	corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator"
	hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub"
	jcstypes "gitlink.org.cn/cloudream/jcs-pub/common/types"
	"gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
-
- func (svc *Service) UserLogin(ctx context.Context, msg *corrpc.UserLogin) (*corrpc.UserLoginResp, *rpc.CodeError) {
- log := logger.WithField("Account", msg.Account)
-
- user, err := svc.db.User().GetByAccount(svc.db.DefCtx(), msg.Account)
- if err != nil {
- if err == gorm.ErrRecordNotFound {
- log.Warnf("account not found")
- return nil, rpc.Failed(errorcode.DataNotFound, "account not found")
- }
-
- log.Warnf("getting account: %v", err)
- return nil, rpc.Failed(errorcode.OperationFailed, "getting account: %v", err)
- }
-
- dbPass, err := hex.DecodeString(user.Password)
- if err != nil {
- log.Warnf("decoding password: %v", err)
- return nil, rpc.Failed(errorcode.OperationFailed, "decoding password: %v", err)
- }
-
- if bcrypt.CompareHashAndPassword(dbPass, []byte(msg.Password)) != nil {
- log.Warnf("password not match")
- return nil, rpc.Failed(errorcode.Unauthorized, "password not match")
- }
-
- pubKey, priKey, err := ed25519.GenerateKey(nil)
- if err != nil {
- log.Warnf("generating key: %v", err)
- return nil, rpc.Failed(errorcode.OperationFailed, "generating key: %v", err)
- }
-
- pubKeyStr := hex.EncodeToString(pubKey)
- nowTime := time.Now()
- token := jcstypes.UserAccessToken{
- UserID: user.UserID,
- TokenID: jcstypes.AccessTokenID(uuid.NewString()),
- PublicKey: pubKeyStr,
- ExpiresAt: nowTime.Add(time.Hour),
- CreatedAt: nowTime,
- }
-
- err = svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token)
- if err != nil {
- log.Warnf("creating token: %v", err)
- return nil, rpc.Failed(errorcode.OperationFailed, "creating token: %v", err)
- }
-
- log.Infof("login success, token expires at %v", token.ExpiresAt)
-
- return &corrpc.UserLoginResp{
- Token: token,
- PrivateKey: hex.EncodeToString(priKey),
- }, nil
- }
-
- func (svc *Service) UserRefreshToken(ctx context.Context, msg *corrpc.UserRefreshToken) (*corrpc.UserRefreshTokenResp, *rpc.CodeError) {
- authInfo, ok := rpc.GetAuthInfo(ctx)
- if !ok {
- return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized")
- }
-
- log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID)
-
- pubKey, priKey, err := ed25519.GenerateKey(nil)
- if err != nil {
- log.Warnf("generating key: %v", err)
- return nil, rpc.Failed(errorcode.OperationFailed, "generating key: %v", err)
- }
-
- pubKeyStr := hex.EncodeToString(pubKey)
- nowTime := time.Now()
- token := jcstypes.UserAccessToken{
- UserID: authInfo.UserID,
- TokenID: jcstypes.AccessTokenID(uuid.NewString()),
- PublicKey: pubKeyStr,
- ExpiresAt: nowTime.Add(time.Hour),
- CreatedAt: nowTime,
- }
-
- err = svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token)
- if err != nil {
- log.Warnf("creating token: %v", err)
- return nil, rpc.Failed(errorcode.OperationFailed, "creating token: %v", err)
- }
-
- log.Infof("refresh token success, new token expires at %v", token.ExpiresAt)
-
- return &corrpc.UserRefreshTokenResp{
- Token: token,
- PrivateKey: hex.EncodeToString(priKey),
- }, nil
- }
-
- func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*corrpc.UserLogoutResp, *rpc.CodeError) {
- authInfo, ok := rpc.GetAuthInfo(ctx)
- if !ok {
- return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized")
- }
-
- log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID)
-
- loaded, err := db.DoTx01(svc.db, func(tx db.SQLContext) ([]jcstypes.LoadedAccessToken, error) {
- token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID)
- if err != nil {
- return nil, err
- }
-
- err = svc.db.UserAccessToken().DeleteByID(tx, token.UserID, token.TokenID)
- if err != nil {
- return nil, err
- }
-
- loaded, err := svc.db.LoadedAccessToken().GetByUserIDAndTokenID(tx, token.UserID, token.TokenID)
- if err != nil {
- return nil, err
- }
-
- err = svc.db.LoadedAccessToken().DeleteAllByUserIDAndTokenID(tx, token.UserID, token.TokenID)
- if err != nil {
- return nil, err
- }
-
- return loaded, nil
- })
- if err != nil {
- log.Warnf("delete access token: %v", err)
- if err == gorm.ErrRecordNotFound {
- return nil, rpc.Failed(errorcode.DataNotFound, "token not found")
- }
-
- return nil, rpc.Failed(errorcode.OperationFailed, "delete access token: %v", err)
- }
-
- svc.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{
- UserID: authInfo.UserID,
- TokenID: authInfo.AccessTokenID,
- })
-
- var loadedHubIDs []jcstypes.HubID
- for _, l := range loaded {
- loadedHubIDs = append(loadedHubIDs, l.HubID)
- }
-
- svc.notifyLoadedHubs(authInfo.UserID, authInfo.AccessTokenID, loadedHubIDs)
-
- return &corrpc.UserLogoutResp{}, nil
- }
-
- func (svc *Service) notifyLoadedHubs(userID jcstypes.UserID, tokenID jcstypes.AccessTokenID, loadedHubIDs []jcstypes.HubID) {
- log := logger.WithField("UserID", userID).WithField("TokenID", tokenID)
-
- loadedHubs, err := svc.db.Hub().BatchGetByID(svc.db.DefCtx(), loadedHubIDs)
- if err != nil {
- log.Warnf("getting hubs: %v", err)
- return
- }
-
- for _, l := range loadedHubs {
- addr, ok := l.Address.(*jcstypes.GRPCAddressInfo)
- if !ok {
- continue
- }
-
- cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort)
- // 不关心返回值
- cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{
- UserID: userID,
- TokenID: tokenID,
- })
- cli.Release()
- }
- }
-
- func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) {
- token, err := db.DoTx01(svc.db, func(tx db.SQLContext) (jcstypes.UserAccessToken, error) {
- token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID)
- if err != nil {
- return jcstypes.UserAccessToken{}, err
- }
-
- err = svc.db.LoadedAccessToken().CreateOrUpdate(tx, jcstypes.LoadedAccessToken{
- UserID: msg.UserID,
- TokenID: msg.TokenID,
- HubID: msg.HubID,
- LoadedAt: time.Now(),
- })
- if err != nil {
- return jcstypes.UserAccessToken{}, fmt.Errorf("creating access token loaded record: %v", err)
- }
-
- return token, nil
- })
- if err != nil {
- if err == gorm.ErrRecordNotFound {
- return nil, rpc.Failed(errorcode.DataNotFound, "token not found")
- }
-
- return nil, rpc.Failed(errorcode.OperationFailed, "loading access token: %v", err)
- }
-
- return &corrpc.HubLoadAccessTokenResp{
- Token: token,
- }, nil
- }
|