Last active
October 31, 2024 20:08
-
-
Save exhuma/a27d12c8015d09ae85e863b09667bc44 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This module contains classes that dispatch calls depending on argument types.
The goal is to handle "media-types" for HTTP APIs with well-defined functions.

An "incoming" request may contain a payload that needs to be stored in the
back-end. We need to "parse/decode" the incoming object depending on its
media-type.

Similarly, when we respond to the request we may need to convert the object to
something the client understands. The client chooses the representation it
wants through the media-types listed in its "Accept" header.

While the implementation in this module does not cover all needs yet, the
dispatchers provide building blocks that can be extended.

The dispatchers are kept separate for "incoming" requests and "outgoing"
responses. Handling for both cases is distinct enough that sharing code would
lead to unnecessary complexity.
""" | |
from datetime import datetime | |
from inspect import signature | |
from typing import Any, Callable | |
from pydantic import BaseModel | |
class CustomerV1(BaseModel):
    # Version 1 customer payload: a single `name` string. The converters in
    # this module treat it as "<first> <last>" when mapping to/from V2.
    # (Comments, not a docstring: pydantic would publish a class docstring
    # as the JSON-schema description.)
    name: str
class CustomerV2(BaseModel):
    # Version 2 customer payload: the name is split into two string fields.
    first_name: str
    last_name: str
class CustomerV3(BaseModel):
    # Version 3 customer payload: two string fields.
    # NOTE(review): presumably fname/lname are first/last name, i.e. CustomerV2
    # with abbreviated field names — confirm.
    fname: str
    lname: str
class Dispatcher[T_external]: | |
def __init__(self): | |
self._incoming_handlers: dict[ | |
type[T_external], Callable[[Any], Any] | |
] = {} | |
def register[ | |
T_handler: Callable[[Any], Any] | |
](self, func: T_handler) -> T_handler: | |
sig = signature(func) | |
payload = sig.parameters.get("incoming_payload") | |
if payload is None: | |
raise SyntaxError( | |
f"Function {func!r} must have an argument named 'incoming_payload'" | |
) | |
self._incoming_handlers[payload.annotation] = func | |
return func | |
def dispatch(self, value: T_external) -> Any: | |
"""Dispatch based on Pydantic model type.""" | |
model_type = type(value) | |
if model_type in self._incoming_handlers: | |
return self._incoming_handlers[model_type](value) | |
else: | |
raise TypeError(f"No handler registered for type {model_type}") | |
class ReturnTypeDispatcher[T_internal, T_external]: | |
def __init__(self): | |
self._outgoing_handlers: dict[ | |
type[T_internal], dict[type[T_external], Callable[[Any], Any]] | |
] = {} | |
def register[ | |
T_handler: Callable[[Any], Any] | |
](self, func: T_handler) -> T_handler: | |
sig = signature(func) | |
payload = sig.parameters.get("outgoing_payload") | |
if payload is None: | |
raise SyntaxError( | |
f"Function {func!r} must have an argument named 'outgoing_payload'" | |
) | |
if payload.annotation not in self._outgoing_handlers: | |
self._outgoing_handlers[payload.annotation] = {} | |
self._outgoing_handlers[payload.annotation][ | |
sig.return_annotation | |
] = func | |
return func | |
def dispatch( | |
self, return_type: type[T_external], data: T_internal | |
) -> T_internal: | |
"""Dispatch based on the specified return type.""" | |
data_type = type(data) | |
handlers = self._outgoing_handlers[data_type] | |
if return_type in self._outgoing_handlers: | |
return handlers[return_type](data) | |
else: | |
raise TypeError( | |
f"No handler registered for return type {return_type}" | |
) | |
# Decodes incoming customer payloads; handlers are registered further down.
dispatcher = Dispatcher[CustomerV1 | CustomerV2]()
# Converts stored customers into the version a client asked for. CustomerV3 is
# part of the source types because a CustomerV3 -> CustomerV2 converter is
# registered on this dispatcher (return_customer_v3_v2).
return_dispatcher = ReturnTypeDispatcher[
    CustomerV1 | CustomerV2 | CustomerV3,
    CustomerV1 | CustomerV2,
]()
@dispatcher.register
def handle_v1(incoming_payload: CustomerV1) -> str:
    """Render a version-1 customer payload as ``"V1: <name>"``."""
    name = incoming_payload.name
    return f"V1: {name}"
@dispatcher.register
def handle_v2(incoming_payload: CustomerV2) -> str:
    """Render a version-2 customer payload as ``"V2: <first> <last>"``."""
    first, last = incoming_payload.first_name, incoming_payload.last_name
    return f"V2: {first} {last}"
@return_dispatcher.register
def return_customer_v1_v1(outgoing_payload: CustomerV1) -> CustomerV1:
    """Convert V1 to V1: nothing to do.

    The very same instance is handed back; no copy is made.
    """
    return outgoing_payload
@return_dispatcher.register
def return_customer_v1_v2(outgoing_payload: CustomerV1) -> CustomerV2:
    """Convert a V1 customer (single ``name``) into a V2 customer.

    The first whitespace-separated word becomes ``first_name`` and the rest of
    the name becomes ``last_name``. A single-word name yields an empty
    ``last_name``, and a blank name yields empty strings for both.
    """
    # strip() + maxsplit=1 keeps multi-word surnames intact instead of
    # dropping every word after the second. It also avoids the IndexError the
    # plain ``split()[1]`` raised for one-word or blank names.
    parts = outgoing_payload.name.strip().split(maxsplit=1)
    first_name = parts[0] if parts else ""
    last_name = parts[1] if len(parts) > 1 else ""
    return CustomerV2(first_name=first_name, last_name=last_name)
@return_dispatcher.register
def return_customer_v2_v1(outgoing_payload: CustomerV2) -> CustomerV1:
    """Convert V2 to V1 by joining first and last name with a single space."""
    first = outgoing_payload.first_name
    last = outgoing_payload.last_name
    return CustomerV1(name=f"{first} {last}")
@return_dispatcher.register
def return_customer_v2_v2(outgoing_payload: CustomerV2) -> CustomerV2:
    """Convert V2 to V2: nothing to do.

    The very same instance is handed back; no copy is made.
    """
    return outgoing_payload
@return_dispatcher.register
def return_customer_v3_v2(outgoing_payload: CustomerV3) -> CustomerV2:
    """Placeholder converter from CustomerV3 to CustomerV2; always refuses.

    Registering it occupies the (CustomerV3 -> CustomerV2) slot on
    ``return_dispatcher``, but no conversion is implemented.

    Raises:
        ValueError: always, with a message naming the unsupported conversion.
    """
    # NOTE(review): CustomerV3.fname/lname look like they map onto
    # CustomerV2.first_name/last_name — confirm before implementing.
    raise ValueError("Conversion from CustomerV3 to CustomerV2 is not supported")
instance1 = CustomerV1(name="John Doe")
instance2 = CustomerV2(first_name="John", last_name="Doe")

# Incoming direction: decode each payload version through its handler.
for _payload in (instance1, instance2):
    print(dispatcher.dispatch(_payload))

# Outgoing direction: render each stored instance as every response version.
for _source in (instance1, instance2):
    for _label, _target in (("rV1", CustomerV1), ("rV2", CustomerV2)):
        print(_label, return_dispatcher.dispatch(_target, _source))

# Keep the module namespace free of the loop variables.
del _payload, _source, _label, _target
def process_request(customer: CustomerV1 | CustomerV2 | CustomerV3):
    """Decode *customer* with the module-level ``dispatcher``.

    Returns whatever the registered handler returns. ``dispatcher`` raises
    TypeError for a payload type without a registered handler. This module
    currently registers only CustomerV1 and CustomerV2 handlers, so a
    CustomerV3 payload raises despite the annotation.
    """
    return dispatcher.dispatch(customer)
# Trailing timestamp for the demo output (naive local time); str(datetime)
# is defined as isoformat(" "), so this prints the same text.
print(datetime.now().isoformat(sep=" "))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment