Created
April 9, 2025 16:17
-
-
Save neepoo/41fd0f09e4561ee91b2fbe8ca5bab515 to your computer and use it in GitHub Desktop.
sqla_basic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
from sqlalchemy import Column, Integer, String, Float, ForeignKey, create_engine, select, func, Table | |
from sqlalchemy.orm import declarative_base, Mapped, relationship, sessionmaker | |
# Declarative base class shared by every ORM model in this module.
Base = declarative_base()

# Association table backing the many-to-many "favorites" link between
# users and products; the two foreign keys form a composite primary key.
user_favorites = Table(
    "user_favorites",
    Base.metadata,
    Column("user_id", Integer, ForeignKey("users.id"), primary_key=True),
    Column("product_id", Integer, ForeignKey("products.id"), primary_key=True),
)
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    email = Column(String(100), unique=True)
    age = Column(Integer)

    # Relationships
    # One user -> many orders (mirrored by Order.user).
    orders = relationship("Order", back_populates="user")
    # Many-to-many with Product through the user_favorites association table.
    favorite_products = relationship(
        "Product",
        secondary=user_favorites,
        back_populates="favorited_by",
    )

    def __repr__(self):
        """Return a short debug representation with id, name and email."""
        return f"<User(id={self.id}, name='{self.name}', email='{self.email}')>"
class Product(Base):
    """ORM model mapped to the ``products`` table."""

    __tablename__ = "products"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    price = Column(Float, nullable=False)

    # Relationships
    # One product -> many order lines (mirrored by OrderItem.product).
    order_items = relationship("OrderItem", back_populates="product")
    # Many-to-many with User through the user_favorites association table.
    favorited_by = relationship(
        "User",
        secondary=user_favorites,
        back_populates="favorite_products",
    )

    def __repr__(self):
        """Return a short debug representation with id, name and price."""
        return f"<Product(id={self.id}, name='{self.name}', price={self.price})>"
class Order(Base):
    """ORM model mapped to the ``orders`` table."""

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    # Stored as a "YYYY-MM-DD" string; defaults to the current local date
    # at the moment the row is inserted.
    order_date = Column(
        String(20),
        default=lambda: datetime.datetime.now().strftime("%Y-%m-%d"),
    )
    total_amount = Column(Float, default=0.0)

    # Relationships
    # Many orders -> one user (mirrored by User.orders).
    user = relationship("User", back_populates="orders")
    # One order -> many order lines (mirrored by OrderItem.order).
    items = relationship("OrderItem", back_populates="order")

    def __repr__(self):
        """Return a short debug representation with id, user_id and total."""
        return f"<Order(id={self.id}, user_id={self.user_id}, total={self.total_amount})>"
class OrderItem(Base):
    """ORM model mapped to the ``order_items`` table (one line of an order)."""

    __tablename__ = "order_items"

    id = Column(Integer, primary_key=True)
    order_id = Column(Integer, ForeignKey("orders.id"))
    product_id = Column(Integer, ForeignKey("products.id"))
    quantity = Column(Integer, default=1)
    unit_price = Column(Float)

    # Relationships
    # Many lines -> one order (mirrored by Order.items).
    order = relationship("Order", back_populates="items")
    # Many lines -> one product (mirrored by Product.order_items).
    product = relationship("Product", back_populates="order_items")

    def __repr__(self):
        """Return a short debug representation with id, product_id and quantity."""
        return f"<OrderItem(id={self.id}, product_id={self.product_id}, quantity={self.quantity})>"
# Create an in-memory SQLite database (echo is switched on for the engine)
# and create every table registered on Base.metadata.
engine = create_engine("sqlite:///:memory:", echo=True)
Base.metadata.create_all(engine)

# Build a session factory bound to that engine and open the single
# module-level session that the functions below share.
Session = sessionmaker(bind=engine)
session = Session()
def _add_order(user, order_date, total_amount, lines):
    """Insert one order for ``user`` together with its line items.

    Args:
        user: A flushed ``User`` (its ``id`` must already be assigned).
        order_date: Order date as a ``"YYYY-MM-DD"`` string.
        total_amount: Value stored in ``orders.total_amount``.
        lines: Iterable of ``(product, quantity)`` pairs; each product must
            already be flushed. The line's unit price is copied from
            ``product.price``.

    Returns:
        The newly added ``Order``; its items are pending until the next flush.
    """
    order = Order(user_id=user.id, order_date=order_date, total_amount=total_amount)
    session.add(order)
    # Flush so the database assigns order.id before the items reference it.
    session.flush()
    session.add_all([
        OrderItem(
            order_id=order.id,
            product_id=product.id,
            quantity=quantity,
            unit_price=product.price,
        )
        for product, quantity in lines
    ])
    return order


def create_sample_data():
    """Populate the module-level session with demo users, products and orders.

    Inserts four users, five products and four orders (with their line
    items), prints each user after the first flush, and commits. If any
    step fails the shared session is rolled back and the error re-raised.
    """
    try:
        # Users. ``users.email`` is declared unique, so every user needs a
        # distinct address.
        users = [
            User(name="张三", email="zhangsan@example.com", age=30),
            User(name="李四", email="lisi@example.com", age=25),
            User(name="王五", email="wangwu@example.com", age=35),
            User(name="赵六", email="zhaoliu@example.com", age=28),
        ]
        session.add_all(users)
        session.flush()
        for u in users:
            print(u)

        # Products.
        products = [
            Product(name="笔记本电脑", price=6999.99),
            Product(name="智能手机", price=3999.00),
            Product(name="耳机", price=999.00),
            Product(name="鼠标", price=199.00),
            Product(name="键盘", price=399.00),
        ]
        session.add_all(products)
        session.flush()

        # Refer to rows through the flushed objects instead of assuming the
        # database handed out ids 1..n.
        zhang_san, li_si, wang_wu, _zhao_liu = users
        laptop, phone, headphones, mouse, keyboard = products
        today = datetime.datetime.now().strftime("%Y-%m-%d")

        # NOTE(review): the stored totals of the first and third orders
        # (1111 and 7796.99) do not equal the sum of their line items
        # (14998.98 and 8197.99). They are kept as in the original sample
        # data -- confirm whether that is intentional.
        # Zhang San's first order.
        _add_order(zhang_san, today, 1111, [(laptop, 2), (headphones, 1)])
        # Li Si's order.
        _add_order(li_si, "2025-04-03", 4998.00, [(phone, 1), (headphones, 1)])
        # Wang Wu's order.
        _add_order(
            wang_wu,
            "2025-04-05",
            7796.99,
            [(laptop, 1), (headphones, 1), (mouse, 1)],
        )
        # Zhang San's second order.
        _add_order(zhang_san, "2025-04-05", 399.0, [(keyboard, 1)])

        session.commit()
    except Exception:
        # Leave the shared session usable for whoever handles the error.
        session.rollback()
        raise


create_sample_data()
def print_section(title):
    """Print ``title`` as a banner framed by two 80-character ``=`` rules.

    A blank line is emitted before the first rule and the title is
    indented by one space.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(f" {title}")
    print(rule)
def _demo_scalar():
    """Example 1: fetch single values with ``scalar()``."""
    print_section("示例1: scalar() - 获取单个值")

    # Total number of users.
    total_users = session.execute(select(func.count(User.id))).scalar()
    print(f"总用户数: {total_users}")

    # Id of one specific user.
    stmt = select(User.id).where(User.name == "张三")
    user_id = session.execute(stmt).scalar()
    print(f"用户张三的ID: {user_id}")

    # Highest product price.
    max_price = session.execute(select(func.max(Product.price))).scalar()
    print(f"最高价格的产品: {max_price}")

    # scalar_one_or_none() copes with a query that matches no row.
    stmt = select(User.name).where(User.id == 999)  # no such user id
    name = session.execute(stmt).scalar_one_or_none()
    print(f"ID为999的用户名: {name}")  # prints None


def _demo_scalars():
    """Example 2: fetch one column over many rows with ``scalars()``."""
    print_section("示例 2: scalars() - 获取单列多行数据")

    # Names of all users.
    names = session.execute(select(User.name)).scalars().all()
    print("所有用户的姓名:", names)

    # All Product entities.
    products = session.execute(select(Product)).scalars().all()
    print("所有产品:", products)

    # Ids of the orders whose total exceeds 5000.
    stmt = select(Order.id).where(Order.total_amount > 5000)
    orders = session.execute(stmt).scalars().all()
    print(orders)

    # session.scalars() is shorthand for session.execute(...).scalars().
    product_names = session.scalars(select(Product.name)).all()
    print(f"所有产品名称: {product_names}")


def _demo_multi_column():
    """Example 3: iterate multi-column rows without ``scalars()``."""
    print_section("示例 3: 多列查询 - 不使用 scalars()")

    # User id and name, unpacked per row.
    result = session.execute(select(User.id, User.name))
    print("用户ID和名称:")
    for user_id, user_name in result:
        print(f" ID: {user_id}, 名称: {user_name}")

    # Product name and price, accessed by position.
    result = session.execute(select(Product.name, Product.price))
    print("\n产品名称和价格:")
    for row in result:
        print(f" 产品: {row[0]}, 价格: {row[1]}")


def _demo_scalars_on_multi_column():
    """Example 5: show the pitfall of ``scalars()`` on a multi-column query."""
    print_section("示例 5: 错误使用 scalars() 处理多列查询")

    stmt = select(Order.id, Order.total_amount, User.name).join(User)
    result = session.execute(stmt)
    # scalars() keeps only the first column (Order.id) of every row.
    order_ids = result.scalars().all()
    print(f"使用 scalars() 处理多列查询,只获取到第一列值: {order_ids}")


def _demo_aggregates():
    """Example 6: per-user order count and total spend."""
    print_section("示例 6: 复杂查询和结果处理")

    # An outer join keeps users that have no orders at all; for them the
    # count is 0 and the sum is NULL, which is what the fallback below
    # handles. (An inner join would silently drop those users.)
    stmt = (
        select(
            User.name,
            func.count(Order.id).label("order_count"),
            func.sum(Order.total_amount).label("total_amount"),
        )
        .outerjoin(Order)
        .group_by(User.id)
    )
    result = session.execute(stmt)
    print("用户订单统计:")
    for name, order_count, total_spent in result:
        # NULL sum (user without orders) is reported as 0.
        total_spent = total_spent or 0
        print(f" {name}: {order_count} 个订单, 总消费 ¥{total_spent:.2f}")


def _demo_entity_with_aggregate():
    """Example 7: select a full entity next to an aggregate column."""
    print_section("示例 7: 实体查询与关系")

    # User entities together with their number of orders.
    stmt = (
        select(User, func.count(Order.id).label("order_count"))
        .join(Order)
        .group_by(User.id)
    )
    result = session.execute(stmt)
    print("用户及其订单数:")
    for row in result:
        user = row[0]  # User entity
        order_count = row[1]  # count value
        print(f" {user.name} ({user.email}): {order_count} 个订单")


def _demo_scalar_then_rows():
    """Example 8: combine ``scalar()`` with regular row iteration."""
    print_section("示例 8: 混合使用 scalar() 和标准结果处理")

    stmt = select(User.id).where(User.name == "张三")
    user_id = session.execute(stmt).scalar()

    # All orders placed by that user.
    stmt = select(Order).where(Order.user_id == user_id)
    orders = session.execute(stmt).scalars().all()
    print("张三的订单")
    for order in orders:
        print(f" 订单 #{order.id} - 日期: {order.order_date}, 金额: ¥{order.total_amount}")
        # The lines could also be reached through the order.items
        # relationship; the explicit join is used here so the product name
        # comes back in the same query.
        stmt = (
            select(OrderItem, Product.name)
            .join(Product)
            .where(OrderItem.order_id == order.id)
        )
        items_result = session.execute(stmt)
        print(" 订单项:")
        for item, product_name in items_result:
            total = item.quantity * item.unit_price
            print(f" - {product_name} x{item.quantity}: ¥{total:.2f}")


def _demo_unique():
    """Example 9: de-duplicate scalar results with ``unique()``."""
    print_section("示例 9: 使用 unique() 和 scalars() 组合")

    # Distinct user ages.
    unique_ages = session.execute(select(User.age)).scalars().unique().all()
    print(f"不同的年龄值: {unique_ages}")

    # Product ids that appear in any order line (de-duplicated).
    stmt = select(OrderItem.product_id)
    product_ids = session.execute(stmt).scalars().unique().all()
    print(f"订单中的产品ID: {product_ids}")


def _demo_favorites():
    """Many-to-many example: add two favorite products to user 1 and commit."""
    print_section("示例 xx: 多对多关系,用户和商品收藏")

    user = session.get(User, 1)
    p1 = session.get(Product, 1)
    p2 = session.get(Product, 2)
    user.favorite_products.append(p1)
    user.favorite_products.append(p2)
    # Flush so the association rows exist before reading the reverse side.
    session.flush()
    print(p1.favorited_by)
    session.commit()


def example_demo():
    """Run every query example in order against the module-level session.

    Each example prints its own section banner and results; the last one
    also writes favorite-product links and commits them.
    """
    _demo_scalar()
    _demo_scalars()
    _demo_multi_column()
    _demo_scalars_on_multi_column()
    _demo_aggregates()
    _demo_entity_with_aggregate()
    _demo_scalar_then_rows()
    _demo_unique()
    _demo_favorites()


example_demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment