Last active
April 18, 2024 14:52
-
-
Save Kilo59/aac2d18fd59a82c6a5bb1173510bba64 to your computer and use it in GitHub Desktop.
FastAPI Decorators with Dependencies
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
pip install fastapi uvicorn | |
uvicorn dep_dec_api:app --reload | |
""" | |
import base64
import enum
import functools
import inspect
import logging
from collections.abc import Mapping
from pprint import pformat as pf
from typing import Any

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, model_validator
from typing_extensions import Self, override
# Root logging is configured at INFO level; the module logger inherits it.
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# HTTP bearer-token scheme shared by every auth decorator. The description
# lists the demo token for each role as it appears in the interactive docs.
SECURITY_SCHEME = HTTPBearer(description="Viewer: `Mw==` Editor: `Mg==` Admin: `MQ==`")

# Swagger UI is served at "/", ReDoc at "/docs".
app = FastAPI(redoc_url="/docs", docs_url="/")
class Role(str, enum.Enum):
    """User roles, ranked by privilege: admin > editor > viewer.

    Role mixes in ``str``, so members compare equal to their plain values
    (``Role.admin == "admin"``). Because of that mixin, any ordering operator
    not overridden here would silently fall back to alphabetical ``str``
    comparison (where "admin" < "editor"). All four ordering operators are
    therefore defined in terms of privilege rank. Ordering against anything
    that is not a Role raises ``TypeError``.
    """

    admin = "admin"
    editor = "editor"
    viewer = "viewer"

    @classmethod
    def _order_map(cls) -> Mapping[str, int]:
        """Return the privilege rank of each role value (higher = more privileged)."""
        return {"admin": 3, "editor": 2, "viewer": 1}

    def _rank_pair(self, other: object) -> tuple[int, int]:
        """Return ``(rank of self, rank of other)`` for an ordering comparison.

        Raises:
            TypeError: if ``other`` is not a Role.
        """
        if not isinstance(other, Role):
            raise TypeError(f"Cannot compare Role with {type(other)}")
        order = self._order_map()
        return order[self.value], order[other.value]

    @override
    def __gt__(self, other: object) -> bool:
        mine, theirs = self._rank_pair(other)
        return mine > theirs

    @override
    def __ge__(self, other: object) -> bool:
        mine, theirs = self._rank_pair(other)
        return mine >= theirs

    @override
    def __lt__(self, other: object) -> bool:
        mine, theirs = self._rank_pair(other)
        return mine < theirs

    @override
    def __le__(self, other: object) -> bool:
        mine, theirs = self._rank_pair(other)
        return mine <= theirs
# Import-time sanity check that Role's custom __gt__ ranks by privilege
# (plain str ordering would rank "admin" below "editor").
assert Role.admin > Role.editor > Role.viewer  # noqa: S101
class Users(BaseModel):
    """A registered API user; ``encoded_id`` is derived from ``id`` during validation."""

    id: int
    # base64 of str(id); recomputed by ``encode_id`` whenever ``id`` is supplied.
    encoded_id: str | None = None
    name: str
    role: Role

    @model_validator(mode="before")
    @classmethod
    def encode_id(cls, data: Any) -> Any:
        """Populate ``encoded_id`` from ``id`` before field validation.

        Returns a shallow copy so the caller's dict is never mutated. Input
        that is not a dict, or a dict without ``id``, is passed through
        unchanged. Pydantic then validates it normally (e.g. reporting a
        missing ``id`` as a validation error rather than a ``KeyError``).
        """
        if not isinstance(data, dict) or "id" not in data:
            return data
        encoded_id = base64.b64encode(str(data["id"]).encode("utf-8")).decode("utf-8")
        return {**data, "encoded_id": encoded_id}
class Item(BaseModel):
    # An entry of the in-memory ITEMS store. It is also the request body of
    # POST /items and the response model of the item endpoints.
    id: int
    name: str
    contents: list[str]
class ItemUpdate(BaseModel):
    # Request body for PATCH /items/{id}. Every field is optional and
    # defaults to None.
    name: str | None = None
    contents: list[str] | None = None
def extract_id_from_creds(creds: HTTPAuthorizationCredentials) -> int:
    """Decode a bearer token into a user id.

    A token is the base64 encoding of the decimal user id (e.g. ``"Mw=="`` -> 3).

    Raises:
        HTTPException: 401 if the token is not strict base64, is not UTF-8,
            or does not decode to an integer.
    """
    token = creds.credentials
    try:
        # validate=True rejects characters outside the base64 alphabet instead
        # of silently discarding them. binascii.Error, UnicodeDecodeError and
        # int()'s parse failure are all ValueError subclasses/instances.
        return int(base64.b64decode(token, validate=True).decode("utf-8"))
    except ValueError as err:
        LOGGER.warning("Invalid token: %r", err)
        raise HTTPException(status_code=401, detail="Invalid token") from err
def _authenticate(creds: "HTTPAuthorizationCredentials") -> "Users":
    """Resolve bearer credentials to a registered user.

    Raises:
        HTTPException: 401 if the token is malformed or names no known user.
    """
    user = USERS.get(extract_id_from_creds(creds))
    if not user:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return user


def _has_role(user: "Users", role: "Role") -> bool:
    """Return True if ``user`` holds ``role`` or a more privileged one."""
    # Only == and Role.__gt__ are relied on here. Role mixes in str, so other
    # ordering operators may fall back to alphabetical comparison.
    return user.role == role or user.role > role


def _protect(handler, required_role: "Role | None" = None):
    """Wrap ``handler`` so it runs only for an authenticated user.

    If ``required_role`` is given, the user must also hold that role or a
    more privileged one.

    The returned coroutine function receives FastAPI-injected ``request`` and
    ``creds`` in addition to the handler's own parameters. Its
    ``__signature__`` advertises both sets so FastAPI resolves them, but only
    the handler's own arguments are forwarded.

    NOTE: the handler must not declare parameters named ``request`` or
    ``creds``; inspect.Signature rejects duplicate parameter names.

    Raises (at request time):
        HTTPException: 401 for bad or unknown credentials, 403 for an
            insufficient role.
    """

    async def wrapper(
        request: Request,
        creds=Depends(SECURITY_SCHEME),  # noqa: B008
        *args,
        **kwargs,
    ):
        user = _authenticate(creds)
        if required_role is not None and not _has_role(user, required_role):
            raise HTTPException(status_code=403, detail=f"Forbidden for `{user.role}`")
        return await handler(*args, **kwargs)

    # Capture wrapper's own injectable parameters BEFORE update_wrapper sets
    # __wrapped__; after that, inspect.signature(wrapper) would report the
    # handler's signature instead. Making them keyword-only keeps the merged
    # signature valid even when the handler has defaulted or keyword-only
    # parameters. FastAPI passes them by keyword either way.
    injected = [
        param.replace(kind=inspect.Parameter.KEYWORD_ONLY)
        for param in inspect.signature(wrapper).parameters.values()
        if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    handler_signature = inspect.signature(handler)

    # Copy __name__/__qualname__/__doc__ so FastAPI names the route (summary,
    # operation id) after the handler rather than "wrapper".
    functools.update_wrapper(wrapper, handler)
    # Assign after update_wrapper: it copies handler.__dict__, which may hold a
    # __signature__ set by another decorator.
    wrapper.__signature__ = handler_signature.replace(
        parameters=[*handler_signature.parameters.values(), *injected]
    )
    return wrapper


def auth_required(handler):
    """Require a valid bearer token for ``handler``; any registered user is allowed."""
    return _protect(handler)


def editor_required(handler):
    """Require an ``editor`` or ``admin`` user for ``handler`` (403 otherwise)."""
    return _protect(handler, Role.editor)


def role_required(role: Role):
    """Build a decorator requiring ``role`` or a more privileged one.

    The wrapped endpoint's docstring is prefixed with the required role so it
    shows up in the generated API docs.
    """

    def decorator(handler):
        wrapper = _protect(handler, role)
        initial_doc = handler.__doc__ or ""
        wrapper.__doc__ = f"Role required: `{role}`\n{initial_doc}"
        return wrapper

    return decorator
# In-memory user registry keyed by user id. Bearer tokens encode these ids.
# Keys are taken from each user's own id so the two can never disagree.
USERS: dict[int, Users] = {
    user.id: user
    for user in (
        Users(id=3, name="The Boss", role=Role.admin),
        Users(id=2, name="The Engineer", role=Role.editor),
        Users(id=1, name="The Intern", role=Role.viewer),
    )
}
# Report the demo users through the configured logger instead of printing.
LOGGER.info("Registered users:\n%s", pf(USERS))
# Seed contents of the in-memory item store, keyed by each item's id.
ITEMS: dict[int, Item] = {
    seed.id: seed
    for seed in (
        Item(id=1, name="Item Foo", contents=["foo", "bar"]),
        Item(id=2, name="Item Bar", contents=["baz", "qux"]),
    )
}
@app.get("/items", response_model=list[Item])
@auth_required  # Custom decorator
async def read_items():
    """Return every stored item to any authenticated user."""
    return [*ITEMS.values()]
@app.post("/items", response_model=Item)
@editor_required  # Custom decorator
async def add_item(payload: Item):
    """Store ``payload`` under its id (replacing any item with that id) and echo it back."""
    ITEMS.update({payload.id: payload})
    return payload
@app.patch("/items/{id}", response_model=Item)
@role_required(role=Role.admin)  # Custom decorator
async def update_item(id: int, payload: ItemUpdate):
    """Partially update item ``id`` with the fields present in ``payload``.

    Fields left as ``None`` are untouched. An explicit empty name or empty
    contents list is applied rather than ignored.

    Raises:
        HTTPException: 404 if no item has this id.
    """
    item = ITEMS.get(id)
    if item is None:
        raise HTTPException(status_code=404, detail="Item not found")
    if payload.name is not None:
        item.name = payload.name
    if payload.contents is not None:
        item.contents = payload.contents
    return item
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
uvicorn dep_dec_api:app --reload |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
See https://gist.github.com/md2perpe/ee146e547a0bd910ea9683a2eea47c59