Skip to content

Instantly share code, notes, and snippets.

@asardaes
Last active September 23, 2024 16:04
Show Gist options
  • Save asardaes/2267ef418e0e86d4926fd1112d5f7f73 to your computer and use it in GitHub Desktop.
Save asardaes/2267ef418e0e86d4926fd1112d5f7f73 to your computer and use it in GitHub Desktop.
Shiny for Python app hacked to support the OAuth2 authorization code flow *for client requests*
import secrets
import typing
import uuid
from urllib.parse import urlparse

from requests_oauthlib.oauth2_session import OAuth2Session
from shiny import App
from shiny._connection import StarletteConnection
from shiny.session._session import AppSession
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from starlette.routing import Route
from starlette.websockets import WebSocket

from com.login import ATTR_SESSION_ID, ATTR_OAUTH2_SESSION
from com.login.utils import create_session, login_url, complete_login
from com.shiny.server import StatefulServer
from com.shiny.ui import UI
class AuthenticatedShinyApp(App):
    """Shiny ``App`` that only serves its UI to browsers logged in via OAuth2.

    Each browser gets a random ``session_id`` stored in a signed session
    cookie (handled by Starlette's ``SessionMiddleware``). Every
    ``session_id`` maps to its own ``OAuth2Session``. The root page is served
    only once that session is authorized; otherwise the browser is redirected
    to the provider's login page. The ``session_id`` and ``OAuth2Session`` are
    attached to each Shiny ``AppSession`` so that server code can make
    authenticated client requests.

    Several Shiny sessions (e.g. tabs sharing the same cookie) may use the
    same ``session_id``. Its ``OAuth2Session`` and server state are released
    only when the last of them ends.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        login_redirect_uri: str = "http://localhost:8080/oauth/callback",
        *,
        debug: bool = False,
        session_secret_key: str | None = None,
    ) -> None:
        """
        Args:
            client_id: OAuth2 client id.
            client_secret: OAuth2 client secret.
            login_redirect_uri: Redirect URI handed to the provider. Its path
                is served by this app's OAuth2 callback route.
            debug: Forwarded to ``shiny.App``.
            session_secret_key: Key used to sign the session cookie. When
                omitted, a random key is generated per process. That is
                enough here because the OAuth2 sessions themselves only live
                in this process's memory.
        """
        # NOTE: assigned before super().__init__() in case shiny builds the
        # Starlette app (calling init_starlette_app) during construction.
        self._session_secret_key = session_secret_key or secrets.token_urlsafe(32)
        server = StatefulServer()
        super().__init__(UI, server, debug=debug)
        self.on_shutdown(server.shutdown)
        self._client_id = client_id
        self._client_secret = client_secret
        self._login_redirect_uri = login_redirect_uri
        # One OAuth2Session per browser session_id (the value kept in the cookie).
        self._oauth2_sessions: dict[str, OAuth2Session] = {}
        # OAuth2 callback handler, mounted on the redirect URI's path.
        redirect_components = urlparse(login_redirect_uri)
        self._dependency_handler.routes.append(
            Route(redirect_components.path, endpoint=self._oauth2_callback)
        )

    def init_starlette_app(self) -> Starlette:
        """Build the Starlette app and add cookie-based session support."""
        # SessionMiddleware will handle the cookie with the session_id;
        # each session_id shall be associated with a different OAuth2Session.
        star = super().init_starlette_app()
        star.user_middleware.append(
            Middleware(
                SessionMiddleware,
                secret_key=self._session_secret_key,
                max_age=60 * 60 * 12,  # 12 hours
            )
        )
        return star

    async def _on_root_request_cb(self, request: Request) -> Response:
        """Serve the app if this browser is logged in, else start the login flow."""
        session_id = request.session.get(ATTR_SESSION_ID)
        if session_id is None:
            session_id = str(uuid.uuid4())
            request.session[ATTR_SESSION_ID] = session_id

        oauth2_session = self._oauth2_sessions.get(session_id)
        if oauth2_session is None:
            oauth2_session = create_session(
                self._client_id, self._client_secret, self._login_redirect_uri
            )
            self._oauth2_sessions[session_id] = oauth2_session

        if oauth2_session.authorized:
            return await super()._on_root_request_cb(request)
        # login_url() modifies oauth2_session's state; the callback later
        # completes the login on this same object.
        return RedirectResponse(url=login_url(oauth2_session))

    async def _oauth2_callback(self, request: Request) -> Response:
        """Finish the login started by ``_on_root_request_cb`` and go back to "/"."""
        session_id = request.session.get(ATTR_SESSION_ID)
        if session_id is None:
            # Without the cookie there is nothing to attach a token to.
            # Redirecting to "/" would only loop through the provider again.
            return Response(
                "Missing session cookie.", status_code=400, media_type="text/plain"
            )

        oauth2_session = self._oauth2_sessions.get(session_id)
        if oauth2_session is None:
            # The in-memory OAuth2Session is gone (its Shiny session ended or
            # the process restarted). Restart the login from the root page,
            # which recreates it under the same cookie session_id.
            return RedirectResponse(url="/")

        params = request.query_params
        if "code" not in params or "state" not in params:
            # The provider came back without an authorization code
            # (e.g. an error response).
            return Response(
                "OAuth2 login failed.", status_code=400, media_type="text/plain"
            )

        await complete_login(oauth2_session, request, self._client_secret)
        return RedirectResponse(url="/")

    async def _on_connect_cb(self, ws: WebSocket) -> None:
        """
        Copied code from super class but injecting oauth2_session as an attribute so that
        the server can make use of it.

        Connections without a known, authorized OAuth2Session are closed
        before any Shiny session is created for them.
        """
        await ws.accept()
        # No await between this lookup and attaching the attributes below, so
        # the OAuth2Session cannot be removed in between.
        session_id = ws.session.get(ATTR_SESSION_ID)
        oauth2_session = (
            self._oauth2_sessions.get(session_id) if session_id is not None else None
        )
        if oauth2_session is None or not oauth2_session.authorized:
            await ws.close(code=1008)  # 1008: policy violation
            return

        conn = StarletteConnection(ws)
        session = self._create_session(conn)
        setattr(session, ATTR_SESSION_ID, session_id)
        setattr(session, ATTR_OAUTH2_SESSION, oauth2_session)
        await session._run()

    def _remove_session(self, session: AppSession | str) -> None:
        """Release per-login resources of an ending Shiny session, then unregister it.

        The OAuth2Session and the server state are keyed by the browser's
        session_id. They are released only when no other live Shiny session
        uses the same session_id.
        """
        if isinstance(session, AppSession):
            app_session = session
        else:
            # ``session`` is the Shiny session id itself.
            app_session = self._sessions.get(session)
        session_id = getattr(app_session, ATTR_SESSION_ID, None)

        try:
            if session_id is not None and not self._session_id_in_use(
                session_id, app_session
            ):
                oauth2_session = self._oauth2_sessions.pop(session_id, None)
                if oauth2_session is not None:
                    oauth2_session.close()
                server = typing.cast(StatefulServer, self.server)
                server.clear(session_id)
        finally:
            # Always let shiny drop its own bookkeeping for this session.
            super()._remove_session(session)

    def _session_id_in_use(self, session_id: str, ending: object) -> bool:
        """Return True if a live Shiny session other than *ending* uses *session_id*."""
        return any(
            other is not ending and getattr(other, ATTR_SESSION_ID, None) == session_id
            for other in self._sessions.values()
        )
anyio==4.5.0
appdirs==1.4.4
asgiref==3.8.1
certifi==2024.8.30
charset-normalizer==3.3.2
click==8.1.7
h11==0.14.0
htmltools==0.5.3
idna==3.10
itsdangerous==2.2.0
linkify-it-py==2.0.3
markdown-it-py==3.0.0
mdit-py-plugins==0.4.2
mdurl==0.1.2
oauthlib==3.2.2
packaging==24.1
prompt-toolkit==3.0.36
python-multipart==0.0.9
questionary==2.0.1
requests==2.32.3
requests-oauthlib==2.0.0
setuptools==75.1.0
shiny==1.1.0
sniffio==1.3.1
starlette==0.38.5
typing_extensions==4.12.2
uc-micro-py==1.0.3
urllib3==2.2.3
uvicorn==0.30.6
watchfiles==0.24.0
wcwidth==0.2.13
websockets==13.0.1
import asyncio
from requests_oauthlib.oauth2_session import OAuth2Session
from starlette.requests import Request
def create_session(client_id: str, client_secret: str, redirect_uri: str) -> OAuth2Session:
    """Create a new OAuth2Session (holding no token yet) for one browser session.

    The session requests the ``openid``, ``profile`` and ``offline_access``
    scopes and uses *redirect_uri* as its callback. It is configured with the
    token endpoint and the client credentials for automatic token refresh.

    NOTE(review): no ``token_updater`` is passed. requests-oauthlib documents
    that an automatic refresh then raises ``TokenUpdated`` instead of retrying
    the request; confirm callers handle that.
    """
    requested_scopes = ["openid", "profile", "offline_access"]
    refresh_credentials = {
        "client_id": client_id,
        "client_secret": client_secret,
    }
    return OAuth2Session(
        client_id,
        redirect_uri=redirect_uri,
        scope=requested_scopes,
        auto_refresh_url="https://server.com/token",
        auto_refresh_kwargs=refresh_credentials,
    )
def login_url(session: OAuth2Session) -> str:
    """Return the provider URL to send the browser to for logging in.

    Note that the session has state and calling this method modifies it.
    Only the URL is returned; the accompanying state value is discarded here.
    """
    url, _state = session.authorization_url("https://server.com/auth")
    return url
async def complete_login(session: OAuth2Session, request: Request, client_secret: str) -> None:
    """Exchange the authorization code from the OAuth2 callback for a token.

    Args:
        session: The OAuth2Session used for this browser's login (the one
            ``login_url`` was called with).
        request: The callback request, carrying ``code`` and ``state`` as
            query parameters.
        client_secret: Client secret sent to the token endpoint.

    Raises:
        KeyError: If ``code`` or ``state`` is missing from the query string.
    """
    from urllib.parse import urlencode

    # query_params holds percent-decoded values. Re-encode them so codes or
    # states containing '+', '&', '=', '%', ... survive being parsed again.
    query = urlencode(
        {"code": request.query_params["code"], "state": request.query_params["state"]}
    )
    # host in this url doesn't matter, but https does!
    url = f"https://localhost?{query}"
    # fetch_token is a synchronous call; run it in a worker thread so the
    # event loop is not blocked.
    await asyncio.to_thread(
        session.fetch_token,
        token_url="https://server.com/token",
        authorization_response=url,
        client_secret=client_secret,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment