Skip to content

Instantly share code, notes, and snippets.

@neepoo
Created April 9, 2025 16:17
Show Gist options
  • Save neepoo/49bc87e994ba722908e8ee805bed1764 to your computer and use it in GitHub Desktop.
Save neepoo/49bc87e994ba722908e8ee805bed1764 to your computer and use it in GitHub Desktop.
sqla_async
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy import Column, Integer, String, ForeignKey, select
from sqlalchemy.orm import selectinload, joinedload
# 创建基类
Base = declarative_base()
# 定义模型
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
name = Column(String(50), nullable=False)
email = Column(String(100))
# 设置关系 - 一对多
posts = relationship("Post", back_populates="author")
# 设置关系 - 一对一
profile = relationship("UserProfile", back_populates="user", uselist=False)
def __repr__(self):
return f"<User(id={self.id}, name='{self.name}')>"
class UserProfile(Base):
__tablename__ = 'user_profiles'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('users.id'), unique=True)
bio = Column(String(200))
location = Column(String(100))
# 设置关系 - 多对一
user = relationship("User", back_populates="profile")
def __repr__(self):
return f"<UserProfile(user_id={self.user_id}, location='{self.location}')>"
class Post(Base):
__tablename__ = 'posts'
id = Column(Integer, primary_key=True)
title = Column(String(100), nullable=False)
content = Column(String(500))
author_id = Column(Integer, ForeignKey('users.id'))
# 设置关系 - 多对一
author = relationship("User", back_populates="posts")
# 设置关系 - 一对多
comments = relationship("Comment", back_populates="post")
def __repr__(self):
return f"<Post(id={self.id}, title='{self.title}')>"
class Comment(Base):
__tablename__ = 'comments'
id = Column(Integer, primary_key=True)
content = Column(String(200), nullable=False)
post_id = Column(Integer, ForeignKey('posts.id'))
user_id = Column(Integer, ForeignKey('users.id'))
# 设置关系
post = relationship("Post", back_populates="comments")
user = relationship("User")
def __repr__(self):
return f"<Comment(id={self.id}, content='{self.content[:20]}...')>"
# 创建异步引擎 - 使用 aiosqlite
async def setup_database():
# SQLite URL 中的 +aiosqlite 是关键,它告诉 SQLAlchemy 使用异步驱动
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=True # 设置为 True 可以查看 SQL 语句,方便调试
)
# 创建表
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# 创建会话工厂
async_session = sessionmaker(
engine,
expire_on_commit=False,
class_=AsyncSession
)
# 添加示例数据
async with async_session() as session:
# 创建用户
user1 = User(id=1, name="张三", email="[email protected]")
user2 = User(id=2, name="李四", email="[email protected]")
# 创建用户资料
profile1 = UserProfile(user=user1, bio="我是张三", location="北京")
profile2 = UserProfile(user=user2, bio="我是李四", location="上海")
# 创建文章
post1 = Post(title="张三的第一篇文章", content="内容1", author=user1)
post2 = Post(title="张三的第二篇文章", content="内容2", author=user1)
post3 = Post(title="李四的文章", content="内容3", author=user2)
# 创建评论
comment1 = Comment(content="很好的文章!", post=post1, user=user2)
comment2 = Comment(content="我也想写这样的文章", post=post1, user=user2)
comment3 = Comment(content="学习了", post=post2, user=user2)
# 添加所有对象并提交
session.add_all([user1, user2, profile1, profile2,
post1, post2, post3,
comment1, comment2, comment3])
await session.commit()
return async_session, engine
async def demonstrate_lazy_loading(async_session):
print("\n=== 演示懒加载 ===")
async with async_session() as session:
# 获取单个用户
user = await session.get(User, 1)
print(f"用户: {user.name}")
# 访问关联的用户资料 - 会触发懒加载
print("加载用户资料...")
# SQLAlchemy 2.0+ 中的直接访问方式
# 使用 session.refresh 显式加载关系
await session.refresh(user, ["profile"])
profile = user.profile
print(f"用户资料: 位于 {profile.location}")
# 访问关联的文章 - 会触发懒加载
print("加载用户文章...")
await session.refresh(user, ["posts"])
posts = user.posts
print(f"用户有 {len(posts)} 篇文章:")
# 访问每篇文章的评论 - 每篇文章会触发一次懒加载
for post in posts:
print(f" - {post.title}")
print(f" 加载文章评论...")
await session.refresh(post, ["comments"])
comments = post.comments
if comments:
print(f" 评论数: {len(comments)}")
for comment in comments:
# 获取评论的用户 - 会触发懒加载
await session.refresh(comment, ["user"])
comment_user = comment.user
print(f" - {comment_user.name}: {comment.content}")
async def demonstrate_eager_loading(async_session):
print("\n=== 演示急加载 (Eager Loading) ===")
async with async_session() as session:
# 使用 joinedload 一次性加载用户及其资料
stmt = select(User).where(User.id == 1).options(
joinedload(User.profile)
)
result = await session.execute(stmt)
user = result.scalar_one()
print(f"用户: {user.name}")
# 访问资料不会触发新查询,因为已经预加载了
print(f"用户资料: 位于 {user.profile.location}")
# 使用 selectinload 预加载文章和文章评论
stmt = select(User).where(User.id == 1).options(
selectinload(User.posts).selectinload(Post.comments).joinedload(Comment.user)
)
result = await session.execute(stmt)
user = result.scalar_one()
print("\n预加载文章及评论:")
# 这些访问不会触发新查询,因为已经预加载了
for post in user.posts:
print(f" - {post.title}")
print(f" 评论数: {len(post.comments)}")
for comment in post.comments:
print(f" - {comment.user.name}: {comment.content}")
async def demonstrate_mixed_loading(async_session):
print("\n=== 演示混合加载策略 ===")
async with async_session() as session:
# 只预加载文章,但不预加载评论
stmt = select(User).where(User.id == 1).options(
selectinload(User.posts)
)
result = await session.execute(stmt)
user = result.scalar_one()
print(f"用户: {user.name}")
# 文章已预加载,不会触发查询
print(f"用户有 {len(user.posts)} 篇文章")
# 用户资料未预加载,会触发查询
print("加载用户资料...")
await session.refresh(user, ["profile"])
profile = user.profile
print(f"用户资料: 位于 {profile.location}")
# 文章评论未预加载,会触发查询
for post in user.posts:
print(f" - {post.title}")
print(f" 加载文章评论...")
await session.refresh(post, ["comments"])
comments = post.comments
if comments:
print(f" 评论数: {len(comments)}")
async def demonstrate_get_vs_query(async_session):
print("\n=== 演示 get() 与其他查询方法对比 ===")
async with async_session() as session:
# 使用 get() 获取用户
print("使用 get() 获取用户:")
user_get = await session.get(User, 1)
print(f"结果: {user_get}")
# 使用 select() 查询用户
print("\n使用 select() 查询用户:")
stmt = select(User).where(User.id == 1)
result = await session.execute(stmt)
user_select = result.scalar_one()
print(f"结果: {user_select}")
# 两者是否指向同一对象(身份映射)
print(f"\n两者是否相同: {user_get is user_select}")
async def demonstrate_relationship_modifications(async_session):
print("\n=== 演示关系修改 ===")
async with async_session() as session:
async with session.begin(): # 自动开始事务
# 获取用户
user = await session.get(User, 1)
# 加载 posts 关系
await session.refresh(user, ["posts"])
# 创建新文章并关联到用户
new_post = Post(title="张三的新文章", content="这是一篇新文章")
# 直接分配给关系属性
user.posts.append(new_post)
# 不需要显式 session.add(new_post),因为关系会处理这个
print(f"添加了新文章: {new_post.title}")
# 验证添加成功
async with session.begin():
# 重新加载用户的所有文章
stmt = select(User).where(User.id == 1).options(
selectinload(User.posts)
)
result = await session.execute(stmt)
user = result.scalar_one()
print(f"用户现在有 {len(user.posts)} 篇文章:")
for post in user.posts:
print(f" - {post.title}")
async def main():
print("SQLAlchemy 异步 ORM 与关系加载演示")
print("=====================================")
# 设置数据库并获取会话工厂
async_session, engine = await setup_database()
try:
# 演示懒加载
await demonstrate_lazy_loading(async_session)
# 演示急加载
await demonstrate_eager_loading(async_session)
# 演示混合加载策略
await demonstrate_mixed_loading(async_session)
# 演示 get() 与查询对比
await demonstrate_get_vs_query(async_session)
# 演示关系修改
await demonstrate_relationship_modifications(async_session)
finally:
# 关闭引擎
await engine.dispose()
# 运行主程序
if __name__ == "__main__":
asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment