Created
April 9, 2025 16:17
-
-
Save neepoo/49bc87e994ba722908e8ee805bed1764 to your computer and use it in GitHub Desktop.
sqla_async
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | |
from sqlalchemy.orm import declarative_base, relationship, sessionmaker | |
from sqlalchemy import Column, Integer, String, ForeignKey, select | |
from sqlalchemy.orm import selectinload, joinedload | |
# 创建基类 | |
Base = declarative_base() | |
# 定义模型 | |
class User(Base): | |
__tablename__ = 'users' | |
id = Column(Integer, primary_key=True) | |
name = Column(String(50), nullable=False) | |
email = Column(String(100)) | |
# 设置关系 - 一对多 | |
posts = relationship("Post", back_populates="author") | |
# 设置关系 - 一对一 | |
profile = relationship("UserProfile", back_populates="user", uselist=False) | |
def __repr__(self): | |
return f"<User(id={self.id}, name='{self.name}')>" | |
class UserProfile(Base): | |
__tablename__ = 'user_profiles' | |
id = Column(Integer, primary_key=True) | |
user_id = Column(Integer, ForeignKey('users.id'), unique=True) | |
bio = Column(String(200)) | |
location = Column(String(100)) | |
# 设置关系 - 多对一 | |
user = relationship("User", back_populates="profile") | |
def __repr__(self): | |
return f"<UserProfile(user_id={self.user_id}, location='{self.location}')>" | |
class Post(Base): | |
__tablename__ = 'posts' | |
id = Column(Integer, primary_key=True) | |
title = Column(String(100), nullable=False) | |
content = Column(String(500)) | |
author_id = Column(Integer, ForeignKey('users.id')) | |
# 设置关系 - 多对一 | |
author = relationship("User", back_populates="posts") | |
# 设置关系 - 一对多 | |
comments = relationship("Comment", back_populates="post") | |
def __repr__(self): | |
return f"<Post(id={self.id}, title='{self.title}')>" | |
class Comment(Base): | |
__tablename__ = 'comments' | |
id = Column(Integer, primary_key=True) | |
content = Column(String(200), nullable=False) | |
post_id = Column(Integer, ForeignKey('posts.id')) | |
user_id = Column(Integer, ForeignKey('users.id')) | |
# 设置关系 | |
post = relationship("Post", back_populates="comments") | |
user = relationship("User") | |
def __repr__(self): | |
return f"<Comment(id={self.id}, content='{self.content[:20]}...')>" | |
# 创建异步引擎 - 使用 aiosqlite | |
async def setup_database(): | |
# SQLite URL 中的 +aiosqlite 是关键,它告诉 SQLAlchemy 使用异步驱动 | |
engine = create_async_engine( | |
"sqlite+aiosqlite:///:memory:", | |
echo=True # 设置为 True 可以查看 SQL 语句,方便调试 | |
) | |
# 创建表 | |
async with engine.begin() as conn: | |
await conn.run_sync(Base.metadata.create_all) | |
# 创建会话工厂 | |
async_session = sessionmaker( | |
engine, | |
expire_on_commit=False, | |
class_=AsyncSession | |
) | |
# 添加示例数据 | |
async with async_session() as session: | |
# 创建用户 | |
user1 = User(id=1, name="张三", email="[email protected]") | |
user2 = User(id=2, name="李四", email="[email protected]") | |
# 创建用户资料 | |
profile1 = UserProfile(user=user1, bio="我是张三", location="北京") | |
profile2 = UserProfile(user=user2, bio="我是李四", location="上海") | |
# 创建文章 | |
post1 = Post(title="张三的第一篇文章", content="内容1", author=user1) | |
post2 = Post(title="张三的第二篇文章", content="内容2", author=user1) | |
post3 = Post(title="李四的文章", content="内容3", author=user2) | |
# 创建评论 | |
comment1 = Comment(content="很好的文章!", post=post1, user=user2) | |
comment2 = Comment(content="我也想写这样的文章", post=post1, user=user2) | |
comment3 = Comment(content="学习了", post=post2, user=user2) | |
# 添加所有对象并提交 | |
session.add_all([user1, user2, profile1, profile2, | |
post1, post2, post3, | |
comment1, comment2, comment3]) | |
await session.commit() | |
return async_session, engine | |
async def demonstrate_lazy_loading(async_session): | |
print("\n=== 演示懒加载 ===") | |
async with async_session() as session: | |
# 获取单个用户 | |
user = await session.get(User, 1) | |
print(f"用户: {user.name}") | |
# 访问关联的用户资料 - 会触发懒加载 | |
print("加载用户资料...") | |
# SQLAlchemy 2.0+ 中的直接访问方式 | |
# 使用 session.refresh 显式加载关系 | |
await session.refresh(user, ["profile"]) | |
profile = user.profile | |
print(f"用户资料: 位于 {profile.location}") | |
# 访问关联的文章 - 会触发懒加载 | |
print("加载用户文章...") | |
await session.refresh(user, ["posts"]) | |
posts = user.posts | |
print(f"用户有 {len(posts)} 篇文章:") | |
# 访问每篇文章的评论 - 每篇文章会触发一次懒加载 | |
for post in posts: | |
print(f" - {post.title}") | |
print(f" 加载文章评论...") | |
await session.refresh(post, ["comments"]) | |
comments = post.comments | |
if comments: | |
print(f" 评论数: {len(comments)}") | |
for comment in comments: | |
# 获取评论的用户 - 会触发懒加载 | |
await session.refresh(comment, ["user"]) | |
comment_user = comment.user | |
print(f" - {comment_user.name}: {comment.content}") | |
async def demonstrate_eager_loading(async_session): | |
print("\n=== 演示急加载 (Eager Loading) ===") | |
async with async_session() as session: | |
# 使用 joinedload 一次性加载用户及其资料 | |
stmt = select(User).where(User.id == 1).options( | |
joinedload(User.profile) | |
) | |
result = await session.execute(stmt) | |
user = result.scalar_one() | |
print(f"用户: {user.name}") | |
# 访问资料不会触发新查询,因为已经预加载了 | |
print(f"用户资料: 位于 {user.profile.location}") | |
# 使用 selectinload 预加载文章和文章评论 | |
stmt = select(User).where(User.id == 1).options( | |
selectinload(User.posts).selectinload(Post.comments).joinedload(Comment.user) | |
) | |
result = await session.execute(stmt) | |
user = result.scalar_one() | |
print("\n预加载文章及评论:") | |
# 这些访问不会触发新查询,因为已经预加载了 | |
for post in user.posts: | |
print(f" - {post.title}") | |
print(f" 评论数: {len(post.comments)}") | |
for comment in post.comments: | |
print(f" - {comment.user.name}: {comment.content}") | |
async def demonstrate_mixed_loading(async_session): | |
print("\n=== 演示混合加载策略 ===") | |
async with async_session() as session: | |
# 只预加载文章,但不预加载评论 | |
stmt = select(User).where(User.id == 1).options( | |
selectinload(User.posts) | |
) | |
result = await session.execute(stmt) | |
user = result.scalar_one() | |
print(f"用户: {user.name}") | |
# 文章已预加载,不会触发查询 | |
print(f"用户有 {len(user.posts)} 篇文章") | |
# 用户资料未预加载,会触发查询 | |
print("加载用户资料...") | |
await session.refresh(user, ["profile"]) | |
profile = user.profile | |
print(f"用户资料: 位于 {profile.location}") | |
# 文章评论未预加载,会触发查询 | |
for post in user.posts: | |
print(f" - {post.title}") | |
print(f" 加载文章评论...") | |
await session.refresh(post, ["comments"]) | |
comments = post.comments | |
if comments: | |
print(f" 评论数: {len(comments)}") | |
async def demonstrate_get_vs_query(async_session): | |
print("\n=== 演示 get() 与其他查询方法对比 ===") | |
async with async_session() as session: | |
# 使用 get() 获取用户 | |
print("使用 get() 获取用户:") | |
user_get = await session.get(User, 1) | |
print(f"结果: {user_get}") | |
# 使用 select() 查询用户 | |
print("\n使用 select() 查询用户:") | |
stmt = select(User).where(User.id == 1) | |
result = await session.execute(stmt) | |
user_select = result.scalar_one() | |
print(f"结果: {user_select}") | |
# 两者是否指向同一对象(身份映射) | |
print(f"\n两者是否相同: {user_get is user_select}") | |
async def demonstrate_relationship_modifications(async_session): | |
print("\n=== 演示关系修改 ===") | |
async with async_session() as session: | |
async with session.begin(): # 自动开始事务 | |
# 获取用户 | |
user = await session.get(User, 1) | |
# 加载 posts 关系 | |
await session.refresh(user, ["posts"]) | |
# 创建新文章并关联到用户 | |
new_post = Post(title="张三的新文章", content="这是一篇新文章") | |
# 直接分配给关系属性 | |
user.posts.append(new_post) | |
# 不需要显式 session.add(new_post),因为关系会处理这个 | |
print(f"添加了新文章: {new_post.title}") | |
# 验证添加成功 | |
async with session.begin(): | |
# 重新加载用户的所有文章 | |
stmt = select(User).where(User.id == 1).options( | |
selectinload(User.posts) | |
) | |
result = await session.execute(stmt) | |
user = result.scalar_one() | |
print(f"用户现在有 {len(user.posts)} 篇文章:") | |
for post in user.posts: | |
print(f" - {post.title}") | |
async def main(): | |
print("SQLAlchemy 异步 ORM 与关系加载演示") | |
print("=====================================") | |
# 设置数据库并获取会话工厂 | |
async_session, engine = await setup_database() | |
try: | |
# 演示懒加载 | |
await demonstrate_lazy_loading(async_session) | |
# 演示急加载 | |
await demonstrate_eager_loading(async_session) | |
# 演示混合加载策略 | |
await demonstrate_mixed_loading(async_session) | |
# 演示 get() 与查询对比 | |
await demonstrate_get_vs_query(async_session) | |
# 演示关系修改 | |
await demonstrate_relationship_modifications(async_session) | |
finally: | |
# 关闭引擎 | |
await engine.dispose() | |
# 运行主程序 | |
if __name__ == "__main__": | |
asyncio.run(main()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment