|
- package rpc
-
- import (
- "crypto/tls"
- "fmt"
- "strconv"
-
- cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types"
- "golang.org/x/net/context"
- "google.golang.org/grpc"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/metadata"
- "google.golang.org/grpc/peer"
- "google.golang.org/grpc/status"
- )
-
// TLS server names (SNI) that select which API surface a connection uses.
// tlsConfigSelector picks the TLS configuration from these values.
const (
	// ClientAPISNIV1 is the server name for the client-facing API, version 1.
	ClientAPISNIV1 = "rpc.client.jcs-pub.v1"
	// InternalAPISNIV1 is the server name for the internal API, version 1.
	InternalAPISNIV1 = "rpc.internal.jcs-pub.v1"
)

// Keys used by the access-token authentication interceptors.
const (
	// MetaUserID is the incoming-metadata key holding the user ID as a
	// base-10 integer.
	MetaUserID = "x-jcs-user-id"
	// MetaAccessTokenID is the incoming-metadata key holding the access token ID.
	MetaAccessTokenID = "x-jcs-access-token-id"
	// MetaNonce is the incoming-metadata key holding the request nonce.
	MetaNonce = "x-jcs-nonce"
	// MetaSignature is the incoming-metadata key holding the request signature.
	MetaSignature = "x-jcs-signature"
	// MetaTokenAuthInfo is the context key under which the interceptors store
	// the verified AccessTokenAuthInfo; GetAuthInfo reads it back.
	MetaTokenAuthInfo = "x-jcs-token-auth-info"
)
-
- type AccessTokenAuthInfo struct {
- UserID cortypes.UserID
- AccessTokenID cortypes.AccessTokenID
- Nonce string
- Signature string
- }
-
- type AccessTokenVerifier interface {
- Verify(authInfo AccessTokenAuthInfo) bool
- }
-
- type AccessTokenProvider interface {
- GetAuthInfo() (AccessTokenAuthInfo, error)
- }
-
- func (s *ServerBase) tlsConfigSelector(hello *tls.ClientHelloInfo) (*tls.Config, error) {
- switch hello.ServerName {
- case ClientAPISNIV1:
- return &tls.Config{
- Certificates: []tls.Certificate{s.serverCert},
- ClientAuth: tls.NoClientCert,
- NextProtos: []string{"h2"},
- }, nil
- case InternalAPISNIV1:
- return &tls.Config{
- Certificates: []tls.Certificate{s.serverCert},
- ClientAuth: tls.RequireAndVerifyClientCert,
- ClientCAs: s.rootCA,
- NextProtos: []string{"h2"},
- }, nil
-
- default:
- return nil, fmt.Errorf("unknown server name: %s", hello.ServerName)
- }
- }
-
- func (s *ServerBase) authUnary(
- ctx context.Context,
- req interface{},
- info *grpc.UnaryServerInfo,
- handler grpc.UnaryHandler,
-
- ) (resp any, err error) {
- pr, ok := peer.FromContext(ctx)
- if !ok {
- return nil, status.Error(codes.Unauthenticated, "no peer found in context")
- }
-
- tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo)
- if !ok {
- return nil, status.Error(codes.Unauthenticated, "no tls info found in peer")
- }
-
- // 如果是使用interanl ServerName通过的TLS认证,则直接放行
- if tlsInfo.State.ServerName == InternalAPISNIV1 {
- return handler(ctx, req)
- }
-
- // 如果是无需认证的API,则直接放行
- if s.noAuthAPIs[info.FullMethod] {
- return handler(ctx, req)
- }
-
- // 否则要进行额外的Token认证
-
- if !s.accessTokenAuthAPIs[info.FullMethod] {
- return nil, status.Error(codes.Unauthenticated, "unauthorized access")
- }
-
- meta, ok := metadata.FromIncomingContext(ctx)
- if !ok {
- return nil, status.Error(codes.Unauthenticated, "no metadata found in context")
- }
-
- userIDs := meta.Get(MetaUserID)
- if len(userIDs) != 1 {
- return nil, status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata")
- }
- userID, err := strconv.ParseInt(userIDs[0], 10, 64)
- if err != nil {
- return nil, status.Error(codes.Unauthenticated, "invalid user id in metadata")
- }
-
- accessTokenIDs := meta.Get(MetaAccessTokenID)
- if len(accessTokenIDs) != 1 {
- return nil, status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata")
- }
-
- nonce := meta.Get(MetaNonce)
- if len(nonce) != 1 {
- return nil, status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata")
- }
-
- signature := meta.Get(MetaSignature)
- if len(signature) != 1 {
- return nil, status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata")
- }
-
- authInfo := AccessTokenAuthInfo{
- UserID: cortypes.UserID(userID),
- AccessTokenID: cortypes.AccessTokenID(accessTokenIDs[0]),
- Nonce: nonce[0],
- Signature: signature[0],
- }
- if !s.tokenVerifier.Verify(authInfo) {
- return nil, status.Error(codes.Unauthenticated, "invalid access token")
- }
-
- ctx = context.WithValue(ctx, MetaTokenAuthInfo, authInfo)
- return handler(ctx, req)
- }
-
- func (s *ServerBase) authStream(
- srv any,
- stream grpc.ServerStream,
- info *grpc.StreamServerInfo,
- handler grpc.StreamHandler,
- ) error {
- pr, ok := peer.FromContext(stream.Context())
- if !ok {
- return status.Error(codes.Unauthenticated, "no peer found in context")
- }
-
- tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo)
- if !ok {
- return status.Error(codes.Unauthenticated, "no tls info found in peer")
- }
-
- // 如果是使用interanl ServerName通过的TLS认证,则直接放行
- if tlsInfo.State.ServerName == InternalAPISNIV1 {
- return handler(srv, stream)
- }
-
- // 如果是无需认证的API,则直接放行
- if s.noAuthAPIs[info.FullMethod] {
- return handler(srv, stream)
- }
-
- // 否则要进行额外的Token认证
-
- if !s.accessTokenAuthAPIs[info.FullMethod] {
- return status.Error(codes.Unauthenticated, "unauthorized access")
- }
-
- meta, ok := metadata.FromIncomingContext(stream.Context())
- if !ok {
- return status.Error(codes.Unauthenticated, "no metadata found in context")
- }
-
- userIDs := meta.Get(MetaUserID)
- if len(userIDs) != 1 {
- return status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata")
- }
- userID, err := strconv.ParseInt(userIDs[0], 10, 64)
- if err != nil {
- return status.Error(codes.Unauthenticated, "invalid user id in metadata")
- }
-
- accessTokenIDs := meta.Get(MetaAccessTokenID)
- if len(accessTokenIDs) != 1 {
- return status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata")
- }
-
- nonce := meta.Get(MetaNonce)
- if len(nonce) != 1 {
- return status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata")
- }
-
- signature := meta.Get(MetaSignature)
- if len(signature) != 1 {
- return status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata")
- }
-
- authInfo := AccessTokenAuthInfo{
- UserID: cortypes.UserID(userID),
- AccessTokenID: cortypes.AccessTokenID(accessTokenIDs[0]),
- Nonce: nonce[0],
- Signature: signature[0],
- }
-
- if !s.tokenVerifier.Verify(authInfo) {
- return status.Error(codes.Unauthenticated, "invalid access token")
- }
-
- return handler(srv, &serverStream{stream, context.WithValue(stream.Context(), MetaTokenAuthInfo, authInfo)})
- }
-
- type serverStream struct {
- grpc.ServerStream
- ctx context.Context
- }
-
- func (s *serverStream) Context() context.Context {
- return s.ctx
- }
-
- func GetAuthInfo(ctx context.Context) (AccessTokenAuthInfo, bool) {
- val := ctx.Value(MetaTokenAuthInfo)
- if val == nil {
- return AccessTokenAuthInfo{}, false
- }
- authInfo, ok := val.(AccessTokenAuthInfo)
- return authInfo, ok
- }
|