Last active
September 23, 2024 16:04
-
-
Save asardaes/2267ef418e0e86d4926fd1112d5f7f73 to your computer and use it in GitHub Desktop.
Shiny app in Python hacked to support the OAuth2 authorization code flow *for client requests*
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
import uuid | |
from urllib.parse import urlparse | |
from requests_oauthlib.oauth2_session import OAuth2Session | |
from shiny import App | |
from shiny._connection import StarletteConnection | |
from shiny.session._session import AppSession | |
from starlette.applications import Starlette | |
from starlette.middleware import Middleware | |
from starlette.middleware.sessions import SessionMiddleware | |
from starlette.requests import Request | |
from starlette.responses import RedirectResponse, Response | |
from starlette.routing import Route | |
from starlette.websockets import WebSocket | |
from com.login import ATTR_SESSION_ID, ATTR_OAUTH2_SESSION | |
from com.login.utils import create_session, login_url, complete_login | |
from com.shiny.server import StatefulServer | |
from com.shiny.ui import UI | |
class AuthenticatedShinyApp(App):
    """Shiny ``App`` whose root page and websocket require an OAuth2 login.

    A signed cookie (handled by Starlette's ``SessionMiddleware``) carries a
    random session id. Each session id is mapped, in memory, to its own
    ``OAuth2Session``. Unauthenticated visitors hitting ``/`` are redirected to
    the authorization server; the callback route completes the login and
    sends them back to ``/``. When the Shiny session is created, the cookie's
    session id and the ``OAuth2Session`` are attached to it as attributes so
    the server logic can issue authorized requests.
    """

    # Lifetime of the session cookie, in seconds (12 hours).
    _SESSION_COOKIE_MAX_AGE = 60 * 60 * 12

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        login_redirect_uri: str = "http://localhost:8080/oauth/callback",
        *,
        debug: bool = False,
        session_secret_key: str = "top_secret",
    ) -> None:
        """Build the app and register the OAuth2 callback route.

        Args:
            client_id: OAuth2 client id.
            client_secret: OAuth2 client secret.
            login_redirect_uri: Redirect URI registered with the authorization
                server; its path component becomes the callback route.
            debug: Forwarded to ``App.__init__``.
            session_secret_key: Key used by ``SessionMiddleware`` to sign the
                session cookie. SECURITY: the default is a placeholder kept
                for backward compatibility; pass a real secret in any
                non-local deployment.
        """
        # Assigned before super().__init__() so the value is available to
        # init_starlette_app() no matter when the base class invokes it.
        # NOTE(review): presumably App.__init__ builds the Starlette app -
        # confirm against the installed shiny version.
        self._session_secret_key = session_secret_key

        server = StatefulServer()
        super().__init__(UI, server, debug=debug)
        self.on_shutdown(server.shutdown)

        self._client_id = client_id
        self._client_secret = client_secret
        self._login_redirect_uri = login_redirect_uri
        # Cookie session id -> OAuth2 session of that browser session.
        self._oauth2_sessions: dict[str, OAuth2Session] = {}

        # OAuth2 callback handler
        redirect_components = urlparse(login_redirect_uri)
        self._dependency_handler.routes.append(
            Route(redirect_components.path, endpoint=self._oauth2_callback)
        )

    def init_starlette_app(self) -> Starlette:
        """Return the parent's Starlette app with ``SessionMiddleware`` added."""
        # SessionMiddleware will handle the cookie with the session_id,
        # each session_id shall be associated with a different OAuth2Session
        star = super().init_starlette_app()
        # getattr fallback keeps this working even if the attribute was never
        # set (e.g. an instance created without running __init__).
        secret_key = getattr(self, "_session_secret_key", "top_secret")
        star.user_middleware.append(
            Middleware(
                SessionMiddleware,
                secret_key=secret_key,
                max_age=self._SESSION_COOKIE_MAX_AGE,
            )
        )
        return star

    async def _on_root_request_cb(self, request: Request) -> Response:
        """Serve the app to authorized sessions; redirect others to the login page."""
        session_id = request.session.get(ATTR_SESSION_ID)
        if session_id is None:
            session_id = str(uuid.uuid4())
            request.session[ATTR_SESSION_ID] = session_id

        oauth2_session = self._oauth2_sessions.get(session_id)
        if oauth2_session is None:
            oauth2_session = create_session(
                self._client_id, self._client_secret, self._login_redirect_uri
            )
            self._oauth2_sessions[session_id] = oauth2_session

        if oauth2_session.authorized:
            return await super()._on_root_request_cb(request)
        return RedirectResponse(url=login_url(oauth2_session))

    async def _oauth2_callback(self, request: Request) -> Response:
        """Complete the login for the cookie's session and redirect to ``/``.

        Returns a 400 response instead of failing with ``KeyError`` when the
        cookie carries no session id or the id is unknown to this process.
        """
        session_id = request.session.get(ATTR_SESSION_ID)
        oauth2_session = (
            self._oauth2_sessions.get(session_id) if session_id is not None else None
        )
        if oauth2_session is None:
            # A redirect to "/" could loop forever for clients that drop
            # cookies, so answer with an explicit client error instead.
            return Response(
                content="Login session not found or expired; please open the application again.",
                status_code=400,
                media_type="text/plain",
            )

        await complete_login(oauth2_session, request, self._client_secret)
        return RedirectResponse(url="/")

    async def _on_connect_cb(self, ws: WebSocket) -> None:
        """
        Copied code from super class but injecting oauth2_session as an attribute so that
        the server can make use of it.

        The websocket is closed without being accepted unless the cookie maps
        to an authorized OAuth2 session, so the login enforced on ``/`` cannot
        be bypassed by connecting to the websocket directly.
        """
        session_id = ws.session.get(ATTR_SESSION_ID)
        oauth2_session = (
            self._oauth2_sessions.get(session_id) if session_id is not None else None
        )
        if oauth2_session is None or not oauth2_session.authorized:
            # 1008 = policy violation.
            await ws.close(code=1008)
            return

        await ws.accept()
        conn = StarletteConnection(ws)
        session = self._create_session(conn)

        setattr(session, ATTR_SESSION_ID, session_id)
        setattr(session, ATTR_OAUTH2_SESSION, oauth2_session)

        await session._run()

    def _remove_session(self, session: AppSession | str) -> None:
        """Release the OAuth2 session and server state tied to a Shiny session.

        Args:
            session: The Shiny session object, or its id as a string.
        """
        if isinstance(session, AppSession):
            app_session = session
        else:
            # ``session`` is already the id here; it has no ``.id`` attribute.
            app_session = self._sessions[session]

        # Default of None covers sessions that never got the attribute attached.
        session_id = getattr(app_session, ATTR_SESSION_ID, None)
        if session_id is not None:
            oauth2_session = self._oauth2_sessions.pop(session_id, None)
            if oauth2_session is not None:
                oauth2_session.close()

            server = typing.cast(StatefulServer, self.server)
            server.clear(session_id)

        super()._remove_session(session)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
anyio==4.5.0 | |
appdirs==1.4.4 | |
asgiref==3.8.1 | |
certifi==2024.8.30 | |
charset-normalizer==3.3.2 | |
click==8.1.7 | |
h11==0.14.0 | |
htmltools==0.5.3 | |
idna==3.10 | |
itsdangerous==2.2.0 | |
linkify-it-py==2.0.3 | |
markdown-it-py==3.0.0 | |
mdit-py-plugins==0.4.2 | |
mdurl==0.1.2 | |
oauthlib==3.2.2 | |
packaging==24.1 | |
prompt-toolkit==3.0.36 | |
python-multipart==0.0.9 | |
questionary==2.0.1 | |
requests==2.32.3 | |
requests-oauthlib==2.0.0 | |
setuptools==75.1.0 | |
shiny==1.1.0 | |
sniffio==1.3.1 | |
starlette==0.38.5 | |
typing_extensions==4.12.2 | |
uc-micro-py==1.0.3 | |
urllib3==2.2.3 | |
uvicorn==0.30.6 | |
watchfiles==0.24.0 | |
wcwidth==0.2.13 | |
websockets==13.0.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
from urllib.parse import urlencode

from requests_oauthlib.oauth2_session import OAuth2Session
from starlette.requests import Request
def create_session(client_id: str, client_secret: str, redirect_uri: str) -> OAuth2Session:
    """Build an ``OAuth2Session`` configured for automatic token refresh.

    Args:
        client_id: OAuth2 client id; also sent with refresh requests.
        client_secret: OAuth2 client secret, sent with refresh requests.
        redirect_uri: Redirect URI for the authorization response.

    Returns:
        A new session requesting the ``openid``, ``profile`` and
        ``offline_access`` scopes.
    """
    requested_scopes = ["openid", "profile", "offline_access"]
    # Extra parameters for token-refresh requests.
    refresh_credentials = {
        'client_id': client_id,
        'client_secret': client_secret,
    }
    # NOTE(review): no token_updater is passed; requests-oauthlib appears to
    # raise TokenUpdated after an automatic refresh in that case - confirm
    # that callers handle it.
    return OAuth2Session(
        client_id,
        scope=requested_scopes,
        redirect_uri=redirect_uri,
        auto_refresh_url="https://server.com/token",
        auto_refresh_kwargs=refresh_credentials,
    )
def login_url(session: OAuth2Session) -> str:
    """
    Return the authorization URL the user must visit to log in.

    Note that the session has state and calling this method modifies it.
    """
    # authorization_url() returns a sequence whose first item is the URL.
    url_and_state = session.authorization_url("https://server.com/auth")
    return url_and_state[0]
async def complete_login(session: OAuth2Session, request: Request, client_secret: str) -> None:
    """Exchange the authorization code in the callback request for a token.

    Args:
        session: The OAuth2 session that started the login; ``fetch_token``
            is called on it.
        request: The callback request; its ``code`` and ``state`` query
            parameters are forwarded to the session.
        client_secret: OAuth2 client secret sent with the token request.

    Raises:
        KeyError: If ``code`` or ``state`` is missing from the query string.
    """
    # query_params holds decoded values, so they must be percent-encoded
    # again before being embedded in a URL; otherwise characters such as
    # "+", "&", "=" or "#" in the code/state corrupt the rebuilt URL.
    query = urlencode(
        {
            "code": request.query_params["code"],
            "state": request.query_params["state"],
        }
    )
    # host in this url doesn't matter, but https does!
    url = f"https://localhost?{query}"

    # fetch_token is a synchronous call; run it in a worker thread so the
    # event loop is not blocked.
    await asyncio.to_thread(
        session.fetch_token,
        token_url="https://server.com/token",
        authorization_response=url,
        client_secret=client_secret,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment