Last active
June 20, 2024 08:28
-
-
Save samdbmg/23ffe8bfe0a30ca9072b282a94e9ac5d to your computer and use it in GitHub Desktop.
Proxy a remote API secured using an OAuth2 client credentials grant, exposing it locally without auth
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Taken partly from https://stackoverflow.com/a/36601467
import os
import time
from datetime import datetime, timedelta

from authlib.integrations.httpx_client import AsyncOAuth2Client
from sanic import Sanic
from sanic.log import logger
from sanic.response import raw
# Configuration is read from the environment at import time.
#
# TOKEN_URL must be the full OAuth2 token endpoint URL; for Keycloak that is
# https://{KEYCLOAK_BASE_URL}/realms/{REALM_NAME}/protocol/openid-connect/token
TOKEN_URL = os.getenv("TOKEN_URL")
# Credentials presented to the token endpoint for the client_credentials grant.
CLIENT_ID = os.getenv("CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
# Base URL of the upstream API that incoming requests are forwarded to.
API_URL = os.getenv("API_URL")

app = Sanic("ProxyApp")
@app.before_server_start
async def setup_client(app):
    """Create the shared OAuth2 HTTP client and fetch an initial access token.

    Runs before the server starts accepting requests. The client is stored on
    ``app.ctx.client`` and the fetched token on ``app.ctx.token``; both are
    read by ``_refresh_token`` and ``proxy_request``.

    Raises:
        RuntimeError: if any of TOKEN_URL, CLIENT_ID, CLIENT_SECRET or API_URL
            is unset or empty, so misconfiguration is reported at startup
            instead of as an obscure failure in the token fetch or in every
            proxied request.
    """
    missing = [
        name
        for name, value in (
            ("TOKEN_URL", TOKEN_URL),
            ("CLIENT_ID", CLIENT_ID),
            ("CLIENT_SECRET", CLIENT_SECRET),
            ("API_URL", API_URL),
        )
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    # The token endpoint and grant type are also recorded on the client itself.
    # NOTE(review): authlib's client checks token expiry on each request with its
    # own leeway; with this metadata it can re-fetch a client_credentials token
    # rather than raising when it deems the cached one expired. Confirm against
    # the installed authlib version.
    client = AsyncOAuth2Client(
        CLIENT_ID,
        CLIENT_SECRET,
        token_endpoint=TOKEN_URL,
        grant_type="client_credentials",
    )
    app.ctx.client = client
    app.ctx.token = await client.fetch_token(TOKEN_URL, grant_type="client_credentials")
    logger.info("Token fetch complete")


@app.after_server_stop
async def _close_client(app):
    """Close the shared HTTP client's connections when the server shuts down."""
    # setup_client may have failed before creating the client.
    client = getattr(app.ctx, "client", None)
    if client is not None:
        await client.aclose()
async def _refresh_token(leeway=30):
    """Re-fetch the access token if it has expired or will within ``leeway`` seconds.

    ``app.ctx.token["expires_at"]`` is a POSIX timestamp. It is compared
    directly with ``time.time()`` rather than via naive local datetimes, so the
    check stays correct across DST transitions.

    A token without ``expires_at`` (the token endpoint reported no lifetime) is
    kept as-is instead of raising ``KeyError`` on every proxied request.

    Args:
        leeway: Seconds before the recorded expiry at which to refresh early,
            so the token does not lapse while a proxied request is in flight.
    """
    expires_at = app.ctx.token.get("expires_at")
    if expires_at is None:
        return
    if expires_at - leeway < time.time():
        logger.info("Token expired or expiring soon, refreshing")
        app.ctx.token = await app.ctx.client.fetch_token(
            TOKEN_URL, grant_type="client_credentials"
        )
# Request headers that are not forwarded upstream:
# - Host, Content-Length and Transfer-Encoding describe the incoming
#   connection and body; the HTTP client recomputes them for the outgoing
#   request, which carries the full body.
# - Connection and Keep-Alive are hop-by-hop.
# - Accept is dropped to suppress any HTML render, since the links won't work.
# - Accept-Encoding is dropped so the upstream only uses encodings the HTTP
#   client negotiates itself and can decode.
_DROP_REQUEST_HEADERS = frozenset(
    {
        "host",
        "accept",
        "accept-encoding",
        "connection",
        "keep-alive",
        "content-length",
        "transfer-encoding",
    }
)

# Response headers that are not copied back to the caller. ``res.content`` is
# the already-decoded body, so the upstream Content-Encoding and Content-Length
# no longer describe it. Framing and hop-by-hop headers are left to Sanic.
_DROP_RESPONSE_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "content-encoding",
        "content-length",
        "transfer-encoding",
    }
)

# Redirect statuses whose Location header is rewritten to point at this proxy.
_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})


@app.route('/<path:path>', methods=["HEAD", "GET", "POST", "PUT", "DELETE"])
async def proxy_request(request, path):
    """Forward ``request`` to the upstream API and relay its response.

    The request goes to ``API_URL`` (or ``API_URL/<path>``) through the shared
    OAuth2 client in ``app.ctx.client``, after refreshing the token if needed.
    The method, query arguments and body are passed through unchanged, as are
    the headers except those in ``_DROP_REQUEST_HEADERS``.

    For redirect responses, occurrences of ``API_URL`` in ``Location`` are
    replaced with the base URL the caller used to reach this proxy, so clients
    are not sent straight to the upstream.

    Args:
        request: The incoming Sanic request. It is not modified.
        path: The matched path below the proxy root, without a leading slash.

    Returns:
        A raw response carrying the upstream status code, body and headers,
        minus those in ``_DROP_RESPONSE_HEADERS``.
    """
    # Tolerate a trailing slash on API_URL when joining and rewriting URLs.
    api_base = API_URL.rstrip("/")
    target_url = f"{api_base}/{path}" if path else API_URL
    logger.info(f"Proxying {request.method} request to /{path} -> {target_url}")

    # Base URL callers used to reach this proxy. Falls back to Sanic's default
    # bind address when the request carried no Host header.
    local_base = f"{request.scheme}://{request.host or '127.0.0.1:8000'}"

    # Filter into a new list rather than deleting from request.headers.
    # A missing header is then simply absent, not a KeyError, and repeated
    # headers are preserved.
    forward_headers = [
        (name, value)
        for name, value in request.headers.items()
        if name.lower() not in _DROP_REQUEST_HEADERS
    ]

    # Refresh the token
    await _refresh_token()
    res = await app.ctx.client.request(
        method=request.method,
        headers=forward_headers,
        url=target_url,
        content=request.body,
        params=request.args,
    )

    # multi_items() keeps repeated headers such as Set-Cookie separate.
    response_headers = []
    for name, value in res.headers.multi_items():
        key = name.lower()
        if key in _DROP_RESPONSE_HEADERS:
            continue
        if key == "location" and res.status_code in _REDIRECT_STATUSES:
            # Rewrite the location header to avoid redirecting upstream
            value = value.replace(api_base, local_base)
        response_headers.append((name, value))
    return raw(res.content, res.status_code, response_headers)
if __name__ == "__main__":
    # Serve on Sanic's default address and port explicitly, with access logging.
    app.run(host="127.0.0.1", port=8000, access_log=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
httpx
authlib
sanic
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment