Skip to content

Instantly share code, notes, and snippets.

@Tishka17
Created May 19, 2025 13:48
Show Gist options
  • Save Tishka17/8b909f9e8cd1697d1d6f424d2975f0cf to your computer and use it in GitHub Desktop.
Save Tishka17/8b909f9e8cd1697d1d6f424d2975f0cf to your computer and use it in GitHub Desktop.
JwtIdProvider
class JwtIdProvider(IdentityProvider):
    """Identity provider that authenticates a gRPC call via a Bearer JWT.

    Looks for an ``authorization`` entry in the call's invocation metadata,
    verifies the token's RS256 signature with a key obtained from a JWKS
    endpoint, checks the ``aud`` claim, and builds a ``User`` from the claims.
    """

    def __init__(
        self,
        jwk_uri: str,
        aud: str,
        context: ServicerContext,
        *,
        jwks_client: PyJWKClient | None = None,
    ) -> None:
        """Create a provider bound to a single call context.

        Args:
            jwk_uri: URL of the JWKS endpoint publishing the signing keys.
                Only used when ``jwks_client`` is not supplied.
            aud: Expected ``aud`` claim; also the key under
                ``resource_access`` that roles are read from.
            context: gRPC servicer context of the current call.
            jwks_client: Optional pre-built JWKS client to reuse. Since this
                provider is tied to one call context, it is presumably created
                per request; a fresh ``PyJWKClient`` per provider discards its
                key cache and refetches the JWKS every time, so passing a
                shared client avoids that.
        """
        self.jwks_client = (
            jwks_client if jwks_client is not None else PyJWKClient(jwk_uri)
        )
        self.aud = aud
        self.context = context

    async def get_user(self) -> User | None:
        """Return the user identified by the call's Bearer token.

        Only the first ``authorization`` metadata entry using the Bearer
        scheme is considered; other schemes are skipped.

        Returns:
            The authenticated ``User``, or ``None`` when the call carries no
            Bearer ``authorization`` metadata.

        Raises:
            ApiError: with ``StatusCode.UNAUTHENTICATED`` if the token fails
                key lookup/verification, or if it has neither an ``email``
                nor a ``preferred_username`` claim to use as the user id.
        """
        import asyncio

        for header, value in self.context.invocation_metadata():
            if header != "authorization":
                continue
            # Auth schemes are case-insensitive (RFC 7235), so accept
            # "bearer"/"BEARER" as well as "Bearer".
            scheme, _, token = value.partition(" ")
            if scheme.lower() != "bearer":
                continue
            token = token.strip()
            try:
                # PyJWKClient is synchronous (its result is not awaitable) and
                # may perform blocking HTTP I/O to fetch the JWKS, so run it
                # in a worker thread rather than on the event loop.
                signing_key = await asyncio.to_thread(
                    self.jwks_client.get_signing_key_from_jwt, token
                )
                decoded = jwt.decode(
                    token,
                    signing_key.key,
                    audience=self.aud,
                    algorithms=["RS256"],
                )
            except jwt.PyJWTError as e:
                logger.warning("Failed to decode JWT: %s", e)
                raise ApiError(
                    status=StatusCode.UNAUTHENTICATED, details="Invalid JWT"
                ) from e

            user_id = decoded.get("email") or decoded.get("preferred_username")
            if not user_id:
                # A verified token without any identity claim must not yield
                # an authenticated user with an empty id.
                logger.warning(
                    "JWT has neither 'email' nor 'preferred_username' claim"
                )
                raise ApiError(
                    status=StatusCode.UNAUTHENTICATED, details="Invalid JWT"
                )

            # Roles live under resource_access.<aud>.roles (looks like the
            # Keycloak client-roles layout — confirm against the issuer).
            # `or` fallbacks also cover claims that are present but null.
            resource_access = decoded.get("resource_access") or {}
            client_access = resource_access.get(self.aud) or {}
            roles = client_access.get("roles") or []
            return User(id=user_id, roles=roles)
        return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment