import logging

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth.models import AnonymousUser
from django.http.cookie import parse_cookie
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken

logger = logging.getLogger(__name__)


# ---------- WEBSOCKET JWT AUTH VALIDATION -----------
@database_sync_to_async
def get_user_from_token(token_string):
    """
    Validate an access JWT and return the user it identifies.

    Wrapped with ``database_sync_to_async`` because it runs ORM queries, so
    callers must ``await`` it.

    Args:
        token_string: The raw encoded access token.

    Returns:
        The matching ``CustomUser``, or ``AnonymousUser`` when the token is
        rejected by simplejwt, carries no user-id claim, refers to a user that
        does not exist or is inactive, or any unexpected error occurs.
        Authentication here is deliberately best-effort and never raises.
    """
    try:
        # simplejwt validates the token on construction and raises TokenError
        # when it is not an acceptable access token.
        access_token = AccessToken(token_string)

        # Honour simplejwt's configurable claim name (defaults to "user_id").
        user_id = access_token.get(api_settings.USER_ID_CLAIM)
        if not user_id:
            return AnonymousUser()

        # Import here to avoid circular imports
        from account.models import CustomUser

        try:
            # Same lookup simplejwt's JWTAuthentication performs (field
            # defaults to "id").
            user = CustomUser.objects.get(**{api_settings.USER_ID_FIELD: user_id})
        except CustomUser.DoesNotExist:
            return AnonymousUser()
    except (TokenError, InvalidToken) as exc:
        logger.warning("JWT validation failed in websocket: %s", exc)
        return AnonymousUser()
    except Exception:
        # Keep the handshake alive (best effort), but make unexpected failures
        # such as database errors visible with a full traceback.
        logger.exception("Unexpected error while authenticating websocket JWT")
        return AnonymousUser()

    # A deactivated account must not keep connecting with a still-valid token;
    # Django's ModelBackend and simplejwt's JWTAuthentication reject it as well.
    if not getattr(user, "is_active", True):
        return AnonymousUser()
    return user


class JWTAuthMiddleware(BaseMiddleware):
    """
    Custom middleware to authenticate WebSocket connections using JWT from cookies.
    Replaces AuthMiddlewareStack / CSRF validation for WebSocket routes.

    Sets ``scope["user"]`` to the authenticated user, or ``AnonymousUser``
    when no valid access token cookie is present.

    NOTE(review): browsers attach cookies to cross-site WebSocket handshakes,
    so cookie-based auth is exposed to cross-site WebSocket hijacking unless
    this middleware is wrapped in channels' ``AllowedHostsOriginValidator``
    (or ``OriginValidator``) — confirm that in the ASGI routing.
    """

    # Name of the cookie that carries the encoded access token.
    cookie_name = "access_token"

    @staticmethod
    def _parse_cookies(scope):
        """Return the handshake's cookies from the ASGI ``scope`` as a dict."""
        # ASGI header values are bytes; latin-1 decoding cannot fail, unlike
        # UTF-8. All ``cookie`` headers are merged, since a client may send
        # more than one.
        raw_cookie = "; ".join(
            value.decode("latin-1")
            for name, value in scope.get("headers", ())
            if name == b"cookie"
        )
        # parse_cookie tolerates ";" with or without a trailing space.
        return parse_cookie(raw_cookie)

    async def __call__(self, scope, receive, send):
        # Copy before writing "user" so the change does not leak into the
        # scope object owned by the outer application.
        scope = dict(scope)

        cookies = self._parse_cookies(scope)
        token = cookies.get(self.cookie_name)
        logger.debug(
            "[WS] cookie keys: %s | token found: %s", sorted(cookies), bool(token)
        )

        scope["user"] = await get_user_from_token(token) if token else AnonymousUser()
        logger.debug(
            "[WS] authenticated as: %s | is_authenticated: %s",
            scope["user"],
            getattr(scope["user"], "is_authenticated", False),
        )

        return await super().__call__(scope, receive, send)
# ---------- END | WEBSOCKET JWT AUTH VALIDATION -----------