Last active: May 24, 2022 13:47
Save alex-pobeditel-2004/5098bac720c4eeb79052b7234346f52d to your computer and use it in GitHub Desktop.
JWT Auth middleware for Django Channels 3.0 and rest_framework_simplejwt - update of @dmwyatt gist
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Original gist: https://gist.github.com/dmwyatt/5cf7e5102ed0a01b7d38aabf322e03b2 | |
""" | |
import logging | |
from typing import Awaitable, Final, List, TYPE_CHECKING, TypedDict | |
from channels.auth import AuthMiddlewareStack | |
from channels.db import database_sync_to_async | |
from django.contrib.auth.models import AnonymousUser | |
from rest_framework.exceptions import AuthenticationFailed | |
from rest_framework_simplejwt.authentication import JWTAuthentication | |
if TYPE_CHECKING:
    # Imported for static type checkers only; never executed at runtime.
    # Point this at your project's user model if it is not CustomUser.
    from authentication.models import CustomUser

# Text a websocket subprotocol must start with to be treated as the one
# carrying the bearer token (see get_bearer_subprotocol / JWTAuth).
TOKEN_STR_PREFIX: Final = "Bearer"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Typed view of the Channels connection scope used in this module.
# ``total=False`` because the declared ``subprotocols`` key may be missing
# (it is read with ``.get`` and an empty-list default).
Scope = TypedDict("Scope", {"subprotocols": List[str]}, total=False)
class QueryAuthMiddleware:
    """
    Middleware for django-channels that authenticates a connection with a JWT
    before handing it to the wrapped application.

    The actual authentication is done by ``get_user_from_scope``. That function
    feeds the scope's *headers* to rest_framework_simplejwt's ``authenticate``.
    NOTE(review): simplejwt reads the token from the ``Authorization`` header
    by default, so the websocket subprotocol is not consulted on this path.

    The middleware accepts both ASGI calling conventions:

    * double-callable (ASGI 2 style): ``app(scope)`` returns a per-connection
      ``QueryAuthMiddlewareInstance``, which is then awaited with
      ``(receive, send)``;
    * single-callable (ASGI 3 style): ``await app(scope, receive, send)``. This
      is how Channels 3 middleware (e.g. ``AllowedHostsOriginValidator``) call
      the application they wrap.
    """

    def __init__(self, inner):
        # The ASGI application to run once authentication has been attempted.
        self.inner = inner

    def __call__(self, scope: Scope, receive=None, send=None):
        """
        Create the per-connection handler for ``scope``.

        :param scope: Connection scope from django-channels.
        :param receive: ASGI receive callable. Pass it together with ``send``
            to use the single-callable (ASGI 3) form.
        :param send: ASGI send callable.
        :return: A ``QueryAuthMiddlewareInstance`` when called with ``scope``
            only. Otherwise, the awaitable that runs that instance with
            ``receive``/``send``.
        """
        instance = QueryAuthMiddlewareInstance(scope, self)
        if receive is None and send is None:
            # Legacy double-callable use: the caller awaits instance(receive, send).
            return instance
        # Single-callable use: hand back the coroutine for the caller to await.
        return instance(receive, send)
class QueryAuthMiddlewareInstance:
    """
    Per-connection application created by ``QueryAuthMiddleware``.

    When awaited, it first authenticates the connection, unless
    ``scope["user"]`` already holds a non-anonymous user. It then delegates to
    the wrapped application with the same (possibly updated) scope.
    """

    def __init__(self, scope: Scope, middleware):
        """
        :param scope: Connection scope. Authentication mutates it in place
            (``user`` and ``auth_error`` keys).
        :param middleware: The ``QueryAuthMiddleware`` that created this
            instance. Its ``inner`` application is what runs next.
        """
        self.scope = scope
        self.middleware = middleware
        self.inner = middleware.inner

    async def __call__(self, receive, send):
        existing_user = self.scope.get("user")
        if not existing_user or existing_user.is_anonymous:
            await self._authenticate()
        return await self.inner(self.scope, receive, send)

    async def _authenticate(self):
        """
        Store the JWT-authenticated user in the scope.

        On failure, the scope gets an ``AnonymousUser`` and the error text
        under ``auth_error``, so later code can inspect why. On success, any
        leftover ``auth_error`` is removed.
        """
        logger.debug("Attempting to authenticate user.")
        try:
            user = await get_user_from_scope(self.scope)
        except (AuthenticationFailed, MissingTokenError) as exc:
            reason = str(exc)
            self.scope["user"] = AnonymousUser()
            self.scope["auth_error"] = reason
            logger.info("Could not auth user: %s", reason)
        else:
            self.scope["user"] = user
            self.scope.pop("auth_error", None)
def JWTBearerProtocolAuthStack(inner):
    """
    Wrap ``inner`` in Channels' ``AuthMiddlewareStack`` and, outermost,
    ``QueryAuthMiddleware``.

    Because ``QueryAuthMiddleware`` is outermost, it runs first and sets
    ``scope["user"]`` before ``AuthMiddlewareStack`` sees the scope. The
    PascalCase name mirrors ``AuthMiddlewareStack`` and is kept for existing
    callers.

    :param inner: The ASGI application (e.g. a ``URLRouter``) to protect.
    :return: The composed middleware application.
    """
    return QueryAuthMiddleware(AuthMiddlewareStack(inner))
def get_bearer_subprotocol(scope: Scope):
    """
    Find the first websocket subprotocol that starts with ``TOKEN_STR_PREFIX``.

    :param scope: django-channels connection scope.
    :return: That subprotocol string, unchanged. Returns ``None`` if no
        subprotocol matches or the scope lists none.
    """
    candidates = (
        offered
        for offered in scope.get("subprotocols", [])
        if offered.startswith(TOKEN_STR_PREFIX)
    )
    return next(candidates, None)
class JWTAuth(JWTAuthentication):
    """
    ``JWTAuthentication`` subclass with a helper that extracts a bearer token
    from a django-channels scope's websocket subprotocols.

    NOTE(review): nothing in this module calls ``get_token_from_request``.
    simplejwt's ``authenticate()`` reads the token from the request header,
    not through this method. So the helper only runs if you call it yourself.
    """

    @classmethod
    def get_token_from_request(cls, scope: Scope) -> str:
        """
        Return the token carried by the bearer subprotocol in ``scope``.

        The request-oriented name is reused to work on a django-channels scope
        instead of an HTTP request.

        :param scope: Scope from django-channels middleware.
        :return: The matching subprotocol with the leading ``TOKEN_STR_PREFIX``
            removed. Any separator that follows the prefix is left in place.
        :raises ValueError: If no subprotocol starts with ``TOKEN_STR_PREFIX``.
        """
        token_string = get_bearer_subprotocol(scope)
        if not token_string:
            raise ValueError("No token provided.")
        # Strip only the leading prefix. ``split(prefix)[1]`` would also cut
        # the token short if the prefix text happened to occur inside it.
        return token_string[len(TOKEN_STR_PREFIX):]
class MissingTokenError(Exception):
    """Raised when simplejwt finds no token to authenticate the connection with."""
class MetaRequest:
    """
    Minimal stand-in for an HTTP request. It exposes a django-channels scope's
    headers as a Django-style ``META`` dict, so rest_framework_simplejwt can
    read them.
    """

    def __init__(self, scope: dict):
        """
        Translate ``scope["headers"]`` into ``self.META``.

        Mirrors the header handling copied from ``django.core.handlers.asgi``:

        * ``content-length`` becomes ``CONTENT_LENGTH``;
        * ``content-type`` becomes ``CONTENT_TYPE``;
        * any other header ``x-foo`` becomes ``HTTP_X_FOO``;
        * repeated headers are joined with commas.

        :param scope: Channels/ASGI scope. Only its ``headers`` entry, an
            iterable of ``(name, value)`` byte pairs, is read.
        """
        self.META = meta = {}
        # Headers that Django stores without the ``HTTP_`` prefix.
        unprefixed = {
            'content-length': 'CONTENT_LENGTH',
            'content-type': 'CONTENT_TYPE',
        }
        for raw_name, raw_value in scope.get('headers', []):
            header = raw_name.decode('latin1')
            key = unprefixed.get(header)
            if key is None:
                key = 'HTTP_%s' % header.upper().replace('-', '_')
            # HTTP/2 allows only ASCII in headers. latin1 is used anyway
            # because it can decode any byte.
            text = raw_value.decode('latin1')
            if key in meta:
                text = meta[key] + ',' + text
            meta[key] = text
@database_sync_to_async
# String annotation: ``CustomUser`` is only imported under TYPE_CHECKING, and
# the previous ``Awaitable[User]`` named an undefined ``User``, which raised a
# NameError when the module was imported. The undecorated function returns
# the user itself; the decorator is what makes calls awaitable.
def get_user_from_scope(scope) -> "CustomUser":
    """
    Authenticate the connection described by ``scope`` with simplejwt.

    The ``database_sync_to_async`` decorator turns this synchronous function
    into one that async middleware awaits: ``await get_user_from_scope(scope)``.

    :param scope: django-channels connection scope. Only its ``headers`` are
        used, via ``MetaRequest``.
    :return: The authenticated user.
    :raises MissingTokenError: If simplejwt's ``authenticate`` returned
        ``None``, meaning no token was found.
    :raises AuthenticationFailed: Presumably propagated from simplejwt when a
        token is present but invalid. Confirm for your simplejwt version.
    """
    auth = JWTAuth()
    # simplejwt expects a request object with a ``META`` dict; build one from
    # the scope's headers.
    meta_request = MetaRequest(scope)
    authenticated = auth.authenticate(meta_request)
    if authenticated is None:
        raise MissingTokenError("Cannot find token in scope.")
    # authenticate() returns a (user, validated_token) pair; only the user is needed.
    user, _validated_token = authenticated
    logger.debug("Authenticated %s", user)
    return user
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.