Created
November 11, 2021 17:13
-
-
Save shirblc/cea7dfd0996d5ede90b9e0a46ad48e48 to your computer and use it in GitHub Desktop.
SQLAlchemy ORM Read/Write Sharding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import os | |
from sqlalchemy.ext.horizontal_shard import ShardedSession, ShardedQuery | |
from sqlalchemy.orm import declarative_base, sessionmaker | |
from sqlalchemy import ( | |
create_engine, | |
Column, | |
Integer, | |
String, | |
ForeignKey, | |
) | |
# Sharding Setup | |
# ----------------------------------------------------------------- | |
# Based on https://docs.sqlalchemy.org/en/14/_modules/examples/sharding/attribute_shard.html | |
# One engine per shard id. Each database URL is read from the environment
# variable paired with the shard id below.
shards = {
    shard_id: create_engine(os.getenv(env_var))
    for shard_id, env_var in (
        ("read", "READ_DB_URL"),
        ("write", "WRITE_DB_URL"),
    )
}
def shard_chooser(mapper, instance, clause=None):
    """Return the shard id to use for ``instance``.

    The arguments are ignored: the answer is always ``"write"``, since
    that is the main database.
    """
    main_shard = "write"
    return main_shard
def id_chooser(query, ident):
    """Return the list of shard ids to search for a primary key.

    The primary key itself (``ident``) carries no shard information, so
    outside of a lazy load every shard is returned. During a lazy load the
    search is limited to the shard the parent object came from, assuming
    related rows live on the same shard as their parent.

    Adjusted from
    https://docs.sqlalchemy.org/en/14/_modules/examples/sharding/attribute_shard.html
    """
    parent_state = query.lazy_loaded_from
    if not parent_state:
        # No parent to take a hint from: look in every shard.
        return ["read", "write"]
    # Lazy load: stay on the parent's shard.
    return [parent_state.identity_token]
def execute_chooser(query):
    """Return the list of shard ids a statement should be executed on.

    ``query`` is not inspected; the result is always a fresh list holding
    only the ``"write"`` shard.

    Adjusted from
    https://docs.sqlalchemy.org/en/14/_modules/examples/sharding/attribute_shard.html
    """
    default_shards = ["write"]
    return default_shards
# Session factory
# Builds ShardedSession objects wired to the shard engines and the three
# chooser callbacks defined in this module.
Session = sessionmaker(
    class_=ShardedSession,
    shards=shards,
    shard_chooser=shard_chooser,
    id_chooser=id_chooser,
    execute_chooser=execute_chooser,
    query_cls=ShardedQuery,
)
# Models | |
# ----------------------------------------------------------------- | |
BaseModel = declarative_base()


class Category(BaseModel):
    """ORM model mapped to the ``categories`` table."""

    __tablename__ = "categories"

    # Integer primary key.
    id = Column(
        Integer,
        primary_key=True,
    )
    # Required, unique name of at most 50 characters.
    name = Column(
        String(50),
        nullable=False,
        unique=True,
    )
    # Optional reference to another row of this same table.
    parent_id = Column(
        Integer,
        ForeignKey("categories.id"),
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from models import Session, Category | |
# Query from write db | |
def query_categories():
    """Print up to 10 categories using the default shard routing.

    Each category is printed as a dict with the keys ``id``, ``name`` and
    ``parent_id``. No shard is selected explicitly here, so routing is left
    to the session's chooser callbacks.

    The session is closed in a ``finally`` clause so it is released even
    when the query or the printing fails.
    """
    session = Session()
    try:
        for category in session.query(Category).limit(10):
            print({
                'id': category.id,
                'name': category.name,
                'parent_id': category.parent_id
            })
    finally:
        session.close()
# Query from read replica | |
def query_categories_shard():
    """Print up to 10 categories read explicitly from the ``"read"`` shard.

    The query is pinned to the read shard with ``set_shard("read")``. Each
    category is printed as a dict with the keys ``id``, ``name`` and
    ``parent_id``.

    The session is closed in a ``finally`` clause so it is released even
    when the query or the printing fails.
    """
    session = Session()
    try:
        categories_query = session.query(Category).set_shard("read")
        for category in categories_query.limit(10):
            print({
                'id': category.id,
                'name': category.name,
                'parent_id': category.parent_id
            })
    finally:
        session.close()
# Add - adds to write db | |
def add_category(name="lalala2"):
    """Insert a new category and commit it.

    Args:
        name: Name of the category to create. Defaults to the sample
            value used by the original example.

    Raises:
        Exception: Any error raised while adding or committing is
            re-raised after the transaction has been rolled back, so
            failures (for instance a duplicate ``name``) are not hidden.
    """
    # A session must be created here; without it every call failed with
    # NameError before touching the database.
    session = Session()
    try:
        session.add(Category(name=name))
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def update_category(category_id=1, name="updated lalala"):
    """Rename an existing category and commit the change.

    Args:
        category_id: Primary key of the category to update. Defaults to 1,
            the value used by the original example.
        name: New name for the category.

    If no category with ``category_id`` exists, nothing is changed and the
    function returns normally.

    Raises:
        Exception: Any error raised while loading or committing is
            re-raised after the transaction has been rolled back.
    """
    session = Session()
    try:
        # The lookup lives inside the try block so the session is closed
        # even when loading the row fails.
        category = session.query(Category).get(category_id)
        if category is None:
            return
        # The instance was loaded through this session, so it is already
        # tracked; no explicit session.add() is needed before committing.
        category.name = name
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def delete_category(category_id=1):
    """Delete a category and commit the change.

    Args:
        category_id: Primary key of the category to delete. Defaults to 1,
            the value used by the original example.

    If no category with ``category_id`` exists, nothing is deleted and the
    function returns normally.

    Raises:
        Exception: Any error raised while loading, deleting or committing
            is re-raised after the transaction has been rolled back.
    """
    session = Session()
    try:
        category = session.query(Category).get(category_id)
        if category is None:
            # Explicit guard instead of letting session.delete(None) raise
            # and having that error silently swallowed.
            return
        session.delete(category)
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment