Last active
February 8, 2024 01:18
-
-
Save justinttl/37acb80fdc1b978e8fc6f68d5b1c3aff to your computer and use it in GitHub Desktop.
FastAPI marshaling performance analysis
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from time import time | |
from typing import List, Type | |
import matplotlib.pyplot as plt | |
import orjson | |
import pandas as pd | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.utils import create_response_field | |
from pydantic import BaseModel | |
class ThingSchema(BaseModel):
    """Smallest nested object used by the benchmark: a single string field."""

    # The only field; the builders fill it with a stringified integer.
    name: str
class ManyObjectsSchema(BaseModel):
    """Payload whose items are nested ``ThingSchema`` models."""

    # One nested model per item, as opposed to the plain strings of ListSchema.
    things: List[ThingSchema]
class ListSchema(BaseModel):
    """Payload that is a single object holding a flat list of strings."""

    # Plain strings, no nested models.
    things: List[str]
def build_many_objects_schema(length: int) -> ManyObjectsSchema:
    """Build a ``ManyObjectsSchema`` holding ``length`` nested things.

    The things are named ``"0"``, ``"1"``, ... ``str(length - 1)``.
    """
    names = map(str, range(length))
    nested = [ThingSchema(name=label) for label in names]
    return ManyObjectsSchema(things=nested)
def build_list_schema(length: int) -> ListSchema:
    """Build a ``ListSchema`` holding ``length`` strings.

    The strings are ``"0"``, ``"1"``, ... ``str(length - 1)``.
    """
    labels = list(map(str, range(length)))
    return ListSchema(things=labels)
def _collect_timings(build_schema, model, lengths):
    """Run ``time_ops`` once per length and return the timings as a DataFrame.

    ``build_schema`` is called with each length to create the instance to
    time. The returned frame is indexed by ``lengths`` and has one column per
    key of the dict returned by ``time_ops``.
    """
    records = []
    for length in lengths:
        print(f"Benchmark for length {length}")
        records.append(time_ops(build_schema(length), model=model))
    return pd.DataFrame.from_records(records, index=lengths)


def _as_share_of_total(frame):
    """Divide every row of ``frame`` by that row's sum."""
    return frame.divide(frame.sum(axis=1), axis=0)


def _save_plot(frame, filename, *, title, ylabel, area=False, plain_ticks=False):
    """Plot ``frame`` against its index, save it to ``filename`` and close it.

    ``area`` selects a stacked area plot instead of a line plot, and
    ``plain_ticks`` turns off offset/scientific tick formatting.
    """
    draw = frame.plot.area if area else frame.plot
    ax = draw(title=title, xlabel="Number of things", ylabel=ylabel)
    if plain_ticks:
        ax.ticklabel_format(useOffset=False, style="plain")
    figure = ax.figure
    figure.savefig(filename)
    # Close the figure rather than only clearing it (plt.clf): a cleared
    # figure stays registered with pyplot, so one figure per plot would pile
    # up in memory for the rest of the run.
    plt.close(figure)


def benchmark():
    """Time the marshaling phases for growing payloads and save the plots.

    Payload sizes run from 0 up to (but excluding) 50000 in steps of 1000.
    For each size, both a ``ListSchema`` and a ``ManyObjectsSchema`` are built
    and timed with ``time_ops``. The following PNG files are written to the
    current working directory:

    * ``validate.png``
    * ``{list,many_objects}_encoding.png``
    * ``{list,many_objects}_serialization.png``
    * ``{list,many_objects}_before_ratio.png`` / ``..._after_ratio.png``
    * ``{list,many_objects}_before_vs_after.png``

    "Before" is the sum of the ``jsonable_encoder``, ``validate`` and ``json``
    timings; "after" is the sum of the ``dict`` and ``orjson`` timings.
    """
    max_length = 50000
    length_step = 1000
    lengths = range(0, max_length, length_step)

    # List Schema
    print("=== Benchmark for single object schemas ===")
    list_times = _collect_timings(build_list_schema, ListSchema, lengths)

    # Many Objects Schema
    print("=== Benchmark for many objects schemas ===")
    many_objects_times = _collect_timings(
        build_many_objects_schema, ManyObjectsSchema, lengths
    )

    # Validation time of both payload shapes on one chart.
    validate_times = pd.concat(
        [list_times["validate"], many_objects_times["validate"]],
        axis=1,
        keys=["List of str", "List of objects"],
    )
    _save_plot(
        validate_times,
        "validate.png",
        title="Validate Times",
        ylabel="Seconds",
        plain_ticks=True,
    )

    before_columns = ["jsonable_encoder", "validate", "json"]
    after_columns = ["dict", "orjson"]
    datasets = (
        (list_times, "List of str", "list"),
        (many_objects_times, "List of objects", "many_objects"),
    )
    for times, label, prefix in datasets:
        _save_plot(
            times[["jsonable_encoder", "dict"]],
            f"{prefix}_encoding.png",
            title=f"Encoding Times ({label})",
            ylabel="Seconds",
        )
        _save_plot(
            times[["json", "orjson"]],
            f"{prefix}_serialization.png",
            title=f"Serialization Times ({label})",
            ylabel="Seconds",
        )

        # Time spent percentages (Before vs After)
        before_times = times[before_columns]
        after_times = times[after_columns]
        _save_plot(
            _as_share_of_total(before_times),
            f"{prefix}_before_ratio.png",
            title=f"Time spent in marshaling phases (Before, {label})",
            ylabel="Percentage",
            area=True,
        )
        _save_plot(
            _as_share_of_total(after_times),
            f"{prefix}_after_ratio.png",
            title=f"Time spent in marshaling phases (After, {label})",
            ylabel="Percentage",
            area=True,
        )

        # Total time taken (Before vs After)
        before_vs_after = pd.concat(
            [before_times.sum(axis=1), after_times.sum(axis=1)],
            axis=1,
            keys=["before", "after"],
        )
        _save_plot(
            before_vs_after,
            f"{prefix}_before_vs_after.png",
            title=f"Time spent in marshaling phases ({label})",
            ylabel="Seconds",
        )
def time_ops(schema: BaseModel, model: Type[BaseModel]):
    """Time each marshaling phase once for ``schema``.

    Args:
        schema: The model instance to encode, validate and serialize.
        model: The model class used to build the response field that
            ``schema`` is validated against.

    Returns:
        A dict of elapsed seconds keyed by phase: ``"jsonable_encoder"``,
        ``"dict"``, ``"validate"``, ``"json"`` and ``"orjson"``.

    Raises:
        AssertionError: If ``jsonable_encoder(schema)`` and ``schema.dict()``
            produce different results (sanity check, skipped under ``-O``).
    """
    # perf_counter is the high-resolution, monotonic clock meant for measuring
    # short intervals; wall-clock time() can be coarse and can jump if the
    # system clock is adjusted.
    from time import perf_counter

    # Jsonable Encoder
    start = perf_counter()
    jsonable_encoded = jsonable_encoder(schema)
    jsonable_encoder_time = perf_counter() - start

    # .dict()
    start = perf_counter()
    dict_encoded = schema.dict()
    dict_time = perf_counter() - start

    # Both encoders must agree, otherwise the comparison is meaningless.
    assert jsonable_encoded == dict_encoded

    # Built outside the timed region so only validation itself is measured.
    response_field = create_response_field(name="Testing", type_=model)

    # Validation
    start = perf_counter()
    # loc is a one-element tuple; ("response") without the trailing comma is
    # just a parenthesized string.
    response_field.validate(schema, {}, loc=("response",))
    validate_time = perf_counter() - start

    # OOTB json
    start = perf_counter()
    json.dumps(dict_encoded)
    json_time = perf_counter() - start

    # orjson
    start = perf_counter()
    orjson.dumps(dict_encoded)
    orjson_time = perf_counter() - start

    return {
        "jsonable_encoder": jsonable_encoder_time,
        "dict": dict_time,
        "validate": validate_time,
        "json": json_time,
        "orjson": orjson_time,
    }
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    benchmark()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment