package auth_jwt import ( "context" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/glog" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "sort" ) // AuthInterceptor is a server interceptor for authentication and authorization type AuthInterceptor struct { jwtManager *JWTManager accessibleRoles map[string][]string } // NewAuthInterceptor returns a new auth interceptor func NewAuthInterceptor(jwtManager *JWTManager, accessibleRoles map[string][]string) *AuthInterceptor { return &AuthInterceptor{jwtManager, accessibleRoles} } // Unary returns a server interceptor function to authenticate and authorize unary RPC func (interceptor *AuthInterceptor) Unary() grpc.UnaryServerInterceptor { return func( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { enableJWT := g.Config().GetBool("system.enableJWT") if !enableJWT { return handler(ctx, req) } //glog.Info("--> unary interceptor: ", info.FullMethod) noAuthenticationList := g.Config().GetStrings("system.noAuthenticationList") if IsContains(info.FullMethod, noAuthenticationList) { return handler(ctx, req) } err := interceptor.authorize(ctx, info.FullMethod) if err != nil { glog.Error(err) return nil, err } return handler(ctx, req) } } // Stream returns a server interceptor function to authenticate and authorize stream RPC func (interceptor *AuthInterceptor) Stream() grpc.StreamServerInterceptor { return func( srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { glog.Info("--> stream interceptor: ", info.FullMethod) enableJWT := g.Config().GetBool("system.enableJWT") if !enableJWT { return handler(srv, stream) } err := interceptor.authorize(stream.Context(), info.FullMethod) if err != nil { glog.Error(err) return err } return handler(srv, stream) } } func (interceptor *AuthInterceptor) authorize(ctx context.Context, method string) error { accessibleRoles, ok := interceptor.accessibleRoles[method] md, ok := metadata.FromIncomingContext(ctx) if !ok { return status.Errorf(codes.Unauthenticated, "没有提供metadata") } values := md["authorization"] if len(values) == 0 { return status.Errorf(codes.Unauthenticated, "不是授权的token") } accessToken := values[0] claims, err := interceptor.jwtManager.Verify(accessToken) if err != nil { return status.Errorf(codes.Unauthenticated, err.Error()) } //TODO:角色处理 glog.Info("claims====", claims) glog.Info("accessibleRoles====", accessibleRoles) //for _, roles := range accessibleRoles { // glog.Info(roles) // //if roles == claims.Roles { // return nil // //} //} return nil //return status.Error(codes.PermissionDenied, "没有访问的权限") } // IsContains 查找值val是否在数组array中存在 func IsContains(target string, str_array []string) bool { sort.Strings(str_array) index := sort.SearchStrings(str_array, target) if index < len(str_array) && str_array[index] == target { return true } return false }