Created
September 25, 2024 14:20
-
-
Save jgadling/bf27a924cd9c34a2a64d2dbf8a5507e6 to your computer and use it in GitHub Desktop.
Example of using Strawberry FieldExtensions to implement FastAPI dependency injection in resolvers for FastAPI 0.115.0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import inspect | |
import types | |
import typing | |
from contextlib import AsyncExitStack | |
from fastapi.dependencies import utils as deputils | |
from fastapi.dependencies.models import Dependant | |
from fastapi.params import Depends as DependsClass | |
from strawberry.extensions import FieldExtension | |
from strawberry.types import Info | |
from strawberry.types.field import StrawberryField | |
def get_func_with_only_deps(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]: | |
"""This function returns a copy of the function with all the arguments that are not DependsClass | |
updated to have a type annotation of "str". We do this because Pydantic explodes if it sees any | |
parameter annotationss that rely on the strawberry.lazy() functionality that Strawberry requires t | |
handle forward-refs properly. Basically Strawberry and pydantic use different and incompatible tricks | |
for handling forward refs and we decided that it was better to workaround Pydantic than Strawberry.""" | |
tmp_func = types.FunctionType( | |
func.__code__, | |
func.__globals__, | |
name=func.__name__, | |
argdefs=func.__defaults__, | |
closure=func.__closure__, | |
) | |
newfunc = functools.update_wrapper(tmp_func, func) | |
signature = inspect.signature(func) | |
for param in signature.parameters.values(): | |
if isinstance(param.default, DependsClass): | |
continue | |
newfunc.__annotations__[param.name] = str | |
return newfunc | |
class DependencyExtension(FieldExtension): | |
def __init__(self) -> None: | |
self.dependency_args: list[typing.Any] = [] | |
self.strawberry_field_names = ["self"] | |
def apply(self, field: StrawberryField) -> None: | |
func = field.base_resolver.wrapped_func # type: ignore | |
func = get_func_with_only_deps(func) # type: ignore | |
self.dependant: Dependant = deputils.get_dependant(path="/", call=func) # type: ignore | |
# Remove fastapi Depends arguments from the list that strawberry tries to resolve | |
field.arguments = [item for item in field.arguments if not isinstance(item.default, DependsClass)] | |
async def resolve_async( | |
self, | |
next_: typing.Callable[..., typing.Any], | |
source: typing.Any, | |
info: Info, | |
**kwargs: dict[str, typing.Any], | |
) -> typing.Any: | |
request = info.context["request"] | |
try: | |
if "dependency_cache" not in request.context: | |
request.context["dependency_cache"] = {} | |
except AttributeError: | |
request.context = {"dependency_cache": {}} | |
async with AsyncExitStack() as async_exit_stack: | |
request.scope["sberrystack"] = AsyncExitStack() | |
solved_result = await deputils.solve_dependencies( | |
request=request, | |
dependant=self.dependant, | |
body={}, | |
dependency_overrides_provider=request.app, | |
dependency_cache=request.context["dependency_cache"], | |
async_exit_stack=async_exit_stack, | |
embed_body_fields=False, | |
) | |
solved_values = solved_result.values | |
new_cache = solved_result.dependency_cache | |
request.context["dependency_cache"].update(new_cache) | |
kwargs = solved_values | kwargs # solved_values has None values that need to be overridden by kwargs | |
res = await next_(source, info, **kwargs) | |
return res |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment