Skip to content

Instantly share code, notes, and snippets.

@Kilo59
Last active April 18, 2024 14:52
Show Gist options
  • Save Kilo59/aac2d18fd59a82c6a5bb1173510bba64 to your computer and use it in GitHub Desktop.
Save Kilo59/aac2d18fd59a82c6a5bb1173510bba64 to your computer and use it in GitHub Desktop.
FastAPI Decorators with Dependencies
"""
pip install fastapi uvicorn
uvicorn dep_dec_api:app --reload
"""
import base64
import enum
import functools
import inspect
import logging
from collections.abc import Callable, Mapping
from pprint import pformat as pf
from typing import Any

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, model_validator
from typing_extensions import Self, override
# Demo bearer tokens are the base64 encoding of a user id (see USERS below):
# "MQ==" -> 1 (viewer), "Mg==" -> 2 (editor), "Mw==" -> 3 (admin).
SECURITY_SCHEME = HTTPBearer(description="Viewer: `MQ==` Editor: `Mg==` Admin: `Mw==`")
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)
# Swagger UI is served at "/" and ReDoc at "/docs".
app = FastAPI(docs_url="/", redoc_url="/docs")
class Role(str, enum.Enum):
    """User roles ordered by privilege: ``viewer < editor < admin``.

    ``Role`` mixes in ``str``, so any ordering operator not overridden here
    would silently fall back to *alphabetical* string comparison (where
    ``"editor" >= "admin"``). All four ordering operators are therefore
    defined in terms of the rank table in :meth:`_order_map`.
    """

    admin = "admin"
    editor = "editor"
    viewer = "viewer"

    @classmethod
    def _order_map(cls) -> Mapping[str, int]:
        """Map each role value to its rank (higher means more privileged)."""
        return {"admin": 3, "editor": 2, "viewer": 1}

    def _ranks(self, other: object) -> tuple[int, int]:
        """Return ``(rank of self, rank of other)`` for an ordering comparison.

        Raises:
            TypeError: if ``other`` is not a ``Role``. Plain strings are
                rejected so they cannot be compared alphabetically by accident.
        """
        if not isinstance(other, Role):
            raise TypeError(f"Cannot compare Role with {type(other)}")
        order = self._order_map()
        return order[self.value], order[other.value]

    def __gt__(self, other: object) -> bool:
        mine, theirs = self._ranks(other)
        return mine > theirs

    def __ge__(self, other: object) -> bool:
        mine, theirs = self._ranks(other)
        return mine >= theirs

    def __lt__(self, other: object) -> bool:
        mine, theirs = self._ranks(other)
        return mine < theirs

    def __le__(self, other: object) -> bool:
        mine, theirs = self._ranks(other)
        return mine <= theirs


# Import-time sanity checks of the privilege ordering.
assert Role.admin > Role.editor  # noqa: S101
assert Role.admin > Role.editor > Role.viewer  # noqa: S101
assert Role.viewer < Role.editor <= Role.editor < Role.admin  # noqa: S101
class Users(BaseModel):
    """A single user record (despite the plural class name).

    ``encoded_id`` is always derived from ``id`` during validation (any value
    supplied for it is overwritten). It is the base64 encoding of the decimal
    id, i.e. the bearer token that ``extract_id_from_creds`` decodes back into
    ``id``.
    """

    id: int
    encoded_id: str | None = None
    name: str
    role: Role

    @model_validator(mode="after")
    def encode_id(self) -> Self:
        """Fill ``encoded_id`` from the already-validated ``id``.

        Running *after* field validation has four effects:
        - ``id`` is a real ``int`` here.
        - The raw input may be any form pydantic accepts, not only a dict.
        - A missing ``id`` is reported as a normal ValidationError rather
          than a KeyError.
        - The caller's input mapping is never mutated in place.
        """
        raw_id = str(self.id).encode("utf-8")
        self.encoded_id = base64.b64encode(raw_id).decode("utf-8")
        return self
# Full item record: the request body of POST /items and the response model of
# the /items endpoints.
class Item(BaseModel):
    id: int  # key under which the item is stored in ITEMS
    name: str
    contents: list[str]
# Partial-update body for PATCH /items/{id}: every field is optional and
# defaults to None, which the handler treats as "leave this field unchanged".
class ItemUpdate(BaseModel):
    name: str | None = None
    contents: list[str] | None = None
def extract_id_from_creds(creds: HTTPAuthorizationCredentials) -> int:
    """Decode a bearer token into a user id.

    Tokens are the base64 encoding of the decimal user id
    (for example ``"Mw=="`` decodes to ``3``).

    Raises:
        HTTPException: 401 if the token is not strict base64, is not UTF-8,
            or does not decode to an integer.
    """
    token = creds.credentials
    try:
        # validate=True rejects characters outside the base64 alphabet instead
        # of silently discarding them. binascii.Error, UnicodeDecodeError and
        # int()'s parse error are all ValueError subclasses.
        decoded = base64.b64decode(token, validate=True).decode("utf-8")
        return int(decoded)
    except ValueError as err:
        LOGGER.warning("Invalid token: %r", err)
        # Chain from the caught exception instance, not the ValueError class.
        raise HTTPException(status_code=401, detail="Invalid token") from err
# Name of the keyword-only parameter the auth decorators add to an endpoint's
# public signature. It is deliberately private-looking so it cannot collide
# with a handler's own parameters (e.g. ``request``); the wrapper pops it
# before forwarding the remaining keyword arguments to the handler.
_AUTH_CREDS_PARAM = "_auth_creds"


def _auth_signature(handler: Callable[..., Any]) -> inspect.Signature:
    """Return ``handler``'s signature extended with the bearer-token dependency.

    FastAPI decides what to inject by inspecting the endpoint's signature. The
    wrapper therefore advertises two things:
    - the handler's own parameters (path, query, body);
    - one keyword-only ``Depends(SECURITY_SCHEME)`` parameter.

    The extra parameter must be keyword-only so it can legally follow handler
    parameters that have defaults or are themselves keyword-only. It is placed
    before a trailing ``**kwargs`` so the signature stays valid.

    Raises:
        TypeError: if ``handler`` already declares the reserved parameter
            name (e.g. the auth decorators are stacked on one endpoint).
    """
    signature = inspect.signature(handler)
    if _AUTH_CREDS_PARAM in signature.parameters:
        name = getattr(handler, "__qualname__", repr(handler))
        raise TypeError(
            f"{name} already declares `{_AUTH_CREDS_PARAM}`; "
            "apply only one auth decorator per endpoint"
        )
    params = list(signature.parameters.values())
    creds_param = inspect.Parameter(
        _AUTH_CREDS_PARAM,
        inspect.Parameter.KEYWORD_ONLY,
        default=Depends(SECURITY_SCHEME),
        annotation=HTTPAuthorizationCredentials,
    )
    if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        params.insert(len(params) - 1, creds_param)
    else:
        params.append(creds_param)
    return signature.replace(parameters=params)


def _guard(handler: Callable[..., Any], minimum: Role | None) -> Callable[..., Any]:
    """Wrap the async ``handler`` so it only runs for an authorised user.

    Args:
        handler: The async endpoint function to protect.
        minimum: Lowest role allowed to call ``handler``; ``None`` admits any
            user found in ``USERS``.

    At request time the wrapper raises ``HTTPException``:
    - 401 for an unknown user, or for a malformed token (via
      ``extract_id_from_creds``);
    - 403 when the user's role ranks below ``minimum``.
    """

    # functools.wraps keeps __name__/__doc__/__qualname__, which FastAPI uses
    # for each route's name, OpenAPI summary and description.
    @functools.wraps(handler)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        creds = kwargs.pop(_AUTH_CREDS_PARAM)
        user = USERS.get(extract_id_from_creds(creds))
        if user is None:
            raise HTTPException(status_code=401, detail="Invalid credentials")
        # ``minimum > user.role`` dispatches to Role.__gt__, which compares by
        # privilege rank rather than alphabetically like inherited str ops.
        if minimum is not None and minimum > user.role:
            raise HTTPException(
                status_code=403, detail=f"Forbidden for `{user.role.value}`"
            )
        return await handler(*args, **kwargs)

    wrapper.__signature__ = _auth_signature(handler)  # type: ignore[attr-defined]
    if minimum is not None:
        wrapper.__doc__ = f"Role required: `{minimum.value}`\n{handler.__doc__ or ''}"
    return wrapper


def auth_required(handler):
    """Allow ``handler`` to be called by any user with a valid bearer token."""
    return _guard(handler, None)


def editor_required(handler):
    """Allow ``handler`` to be called only by ``editor`` or ``admin`` users."""
    return _guard(handler, Role.editor)


def role_required(role: Role):
    """Return a decorator that admits users whose role is ``role`` or higher."""
    # Normalise eagerly. A plain string such as "admin" then works too, and an
    # unknown role fails at decoration time rather than on every request.
    required = Role(role)

    def decorator(handler):
        return _guard(handler, required)

    return decorator
# Demo user table keyed by id. Each user's bearer token is the base64
# encoding of its id, which the Users model stores in ``encoded_id``.
USERS: dict[int, Users] = {
    user.id: user
    for user in (
        Users(id=3, name="The Boss", role="admin"),
        Users(id=2, name="The Engineer", role="editor"),
        Users(id=1, name="The Intern", role="viewer"),
    )
}
print(pf(USERS))
# In-memory item store keyed by item id; the POST and PATCH handlers mutate it.
ITEMS: dict[int, Item] = {
    item.id: item
    for item in (
        Item(id=1, name="Item Foo", contents=["foo", "bar"]),
        Item(id=2, name="Item Bar", contents=["baz", "qux"]),
    )
}
@app.get("/items", response_model=list[Item])
@auth_required  # any user with valid credentials may list items
async def read_items():
    # Snapshot the store's values into a fresh list for the response.
    return [*ITEMS.values()]
@app.post("/items", response_model=Item)
@editor_required  # editors and admins only
async def add_item(payload: Item):
    # Store the item under its own id (overwriting any existing entry with
    # that id) and echo it back.
    ITEMS.update({payload.id: payload})
    return payload
@app.patch("/items/{id}", response_model=Item)
@role_required(role=Role.admin)  # Custom decorator
async def update_item(id: int, payload: ItemUpdate):
    """Partially update an item; only fields sent with a non-null value change.

    Raises:
        HTTPException: 404 if no item with ``id`` exists.
    """
    # ``id`` must keep this name: it binds the ``{id}`` path parameter.
    item = ITEMS.get(id)
    if item is None:
        raise HTTPException(status_code=404, detail="Item not found")
    # Compare against None rather than truthiness, so an explicit empty name
    # or empty contents list is applied instead of being silently ignored.
    if payload.name is not None:
        item.name = payload.name
    if payload.contents is not None:
        item.contents = payload.contents
    return item
uvicorn dep_dec_api:app --reload
@Kilo59
Copy link
Author

Kilo59 commented Apr 18, 2024

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment