Skip to content

Instantly share code, notes, and snippets.

@neepoo
Created April 9, 2025 16:17
Show Gist options
  • Save neepoo/41fd0f09e4561ee91b2fbe8ca5bab515 to your computer and use it in GitHub Desktop.
Save neepoo/41fd0f09e4561ee91b2fbe8ca5bab515 to your computer and use it in GitHub Desktop.
sqla_basic
import datetime
from sqlalchemy import Column, Integer, String, Float, ForeignKey, create_engine, select, func, Table
from sqlalchemy.orm import declarative_base, Mapped, relationship, sessionmaker
# Declarative base shared by every ORM model in this module.
Base = declarative_base()

# Association table for the many-to-many User <-> Product "favorites" link.
# Both columns form a composite primary key, so a given (user, product)
# pair can appear at most once.
user_favorites = Table(
    "user_favorites",
    Base.metadata,
    Column("user_id", Integer, ForeignKey("users.id"), primary_key=True),
    Column("product_id", Integer, ForeignKey("products.id"), primary_key=True),
)
class User(Base):
    """A customer who places orders and can favorite products."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    email = Column(String(100), unique=True)
    age = Column(Integer)

    # Relationships
    # One-to-many: orders placed by this user (the other side is Order.user).
    orders = relationship("Order", back_populates="user")
    # Many-to-many via the user_favorites association table
    # (the other side is Product.favorited_by).
    favorite_products = relationship(
        "Product",
        secondary=user_favorites,
        back_populates="favorited_by",
    )

    def __repr__(self):
        fields = f"id={self.id}, name='{self.name}', email='{self.email}'"
        return f"<User({fields})>"
class Product(Base):
    """A catalogue item that can appear on order lines and be favorited."""

    __tablename__ = "products"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    # NOTE(review): Float is binary floating point; Numeric would avoid
    # rounding artefacts for currency amounts.
    price = Column(Float, nullable=False)

    # Relationships
    # One-to-many: order lines referencing this product (other side: OrderItem.product).
    order_items = relationship("OrderItem", back_populates="product")
    # Many-to-many via user_favorites (other side: User.favorite_products).
    favorited_by = relationship(
        "User",
        secondary=user_favorites,
        back_populates="favorite_products",
    )

    def __repr__(self):
        fields = f"id={self.id}, name='{self.name}', price={self.price}"
        return f"<Product({fields})>"
class Order(Base):
    """A purchase placed by a user; its line items are OrderItem rows."""

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    # Stored as a "YYYY-MM-DD" string. When no value is supplied, the
    # default callable fills in the local (naive) date at insert time.
    order_date = Column(
        String(20),
        default=lambda: datetime.datetime.now().strftime("%Y-%m-%d"),
    )
    # Not derived from the items by the model itself; whoever creates the
    # order is responsible for keeping it consistent with its lines.
    total_amount = Column(Float, default=0.0)

    # Relationships
    user = relationship("User", back_populates="orders")
    items = relationship("OrderItem", back_populates="order")

    def __repr__(self):
        summary = f"id={self.id}, user_id={self.user_id}, total={self.total_amount}"
        return f"<Order({summary})>"
class OrderItem(Base):
    """One line of an order: which product, how many, and at what unit price."""

    __tablename__ = "order_items"

    id = Column(Integer, primary_key=True)
    order_id = Column(Integer, ForeignKey("orders.id"))
    product_id = Column(Integer, ForeignKey("products.id"))
    quantity = Column(Integer, default=1)
    # Per-unit price stored on the line itself, independent of Product.price.
    unit_price = Column(Float)

    # Relationships
    order = relationship("Order", back_populates="items")
    product = relationship("Product", back_populates="order_items")

    def __repr__(self):
        summary = f"id={self.id}, product_id={self.product_id}, quantity={self.quantity}"
        return f"<OrderItem({summary})>"
# In-memory SQLite database. echo=True makes the engine log every SQL
# statement it emits. The schema is created right away.
engine = create_engine("sqlite:///:memory:", echo=True)
Base.metadata.create_all(bind=engine)

# Session factory bound to the engine, plus the single module-level session
# used by the functions below.
Session = sessionmaker(bind=engine)
session = Session()
def _build_order(user, order_date, lines):
    """Build an Order for *user* whose items and total are derived from *lines*.

    Args:
        user: The User placing the order.
        order_date: Date string stored in ``Order.order_date``.
        lines: Sequence of ``(product, quantity)`` pairs.

    Returns:
        A transient Order linked to *user* and carrying one OrderItem per line.
        Each item's ``unit_price`` is copied from ``product.price``. The order's
        ``total_amount`` is computed from those items, so it cannot disagree
        with them.
    """
    items = [
        OrderItem(product=product, quantity=quantity, unit_price=product.price)
        for product, quantity in lines
    ]
    total = round(sum(item.quantity * item.unit_price for item in items), 2)
    return Order(user=user, order_date=order_date, total_amount=total, items=items)


def create_sample_data():
    """Populate the module-level ``session`` with demo users, products and orders.

    Commits once at the end. Objects are linked through relationships rather
    than hard-coded primary keys, so the data does not depend on which ids the
    database happens to assign.
    """
    # Users. Emails must be distinct because User.email is declared unique.
    users = [
        User(name="张三", email="zhangsan@example.com", age=30),
        User(name="李四", email="lisi@example.com", age=25),
        User(name="王五", email="wangwu@example.com", age=35),
        User(name="赵六", email="zhaoliu@example.com", age=28),
    ]
    session.add_all(users)
    session.flush()  # assign primary keys so the printed reprs show real ids
    for u in users:
        print(u)

    # Products
    products = [
        Product(name="笔记本电脑", price=6999.99),
        Product(name="智能手机", price=3999.00),
        Product(name="耳机", price=999.00),
        Product(name="鼠标", price=199.00),
        Product(name="键盘", price=399.00),
    ]
    session.add_all(products)
    session.flush()

    laptop, phone, headphones, mouse, keyboard = products
    zhangsan, lisi, wangwu, _ = users
    today = datetime.datetime.now().strftime("%Y-%m-%d")

    orders = [
        # Zhang San's first order
        _build_order(zhangsan, today, [(laptop, 2), (headphones, 1)]),
        # Li Si's order
        _build_order(lisi, "2025-04-03", [(phone, 1), (headphones, 1)]),
        # Wang Wu's order
        _build_order(wangwu, "2025-04-05", [(laptop, 1), (headphones, 1), (mouse, 1)]),
        # Zhang San's second order
        _build_order(zhangsan, "2025-04-05", [(keyboard, 1)]),
    ]
    # The default save-update cascade on Order.items persists each order's lines.
    session.add_all(orders)
    session.commit()


create_sample_data()
def print_section(title):
    """Print a section banner for *title*.

    Output is a blank line, an 80-character rule of ``=``, the title
    indented by one space, then another rule.
    """
    rule = "=" * 80
    for line in ("\n" + rule, f" {title}", rule):
        print(line)
def _demo_scalar():
    """Example 1: ``Result.scalar()`` for single values."""
    print_section("示例1: scalar() - 获取单个值")

    # Count all users.
    stmt = select(func.count(User.id))
    result = session.execute(stmt)
    total_users = result.scalar()
    print(f"总用户数: {total_users}")

    # One user's id; scalar() returns the first column of the first row.
    stmt = select(User.id).where(User.name == "张三")
    result = session.execute(stmt)
    user_id = result.scalar()
    print(f"用户张三的ID: {user_id}")

    # Highest product price.
    stmt = select(func.max(Product.price))
    result = session.execute(stmt)
    max_price = result.scalar()
    print(f"最高价格的产品: {max_price}")

    # scalar_one_or_none() returns None when no row matches.
    stmt = select(User.name).where(User.id == 999)  # no user has this id
    name = session.execute(stmt).scalar_one_or_none()
    print(f"ID为999的用户名: {name}")  # prints None


def _demo_scalars():
    """Example 2: ``Result.scalars()`` for one column across many rows."""
    print_section("示例 2: scalars() - 获取单列多行数据")

    # Every user's name.
    stmt = select(User.name)
    result = session.execute(stmt)
    names = result.scalars().all()
    print("所有用户的姓名:", names)

    # Selecting an entity yields one Product object per row.
    stmt = select(Product)
    result = session.execute(stmt)
    products = result.scalars().all()
    print("所有产品:", products)

    # Ids of orders whose total exceeds 5000.
    stmt = select(Order.id).where(Order.total_amount > 5000)
    result = session.execute(stmt)
    orders = result.scalars().all()
    print(orders)

    # Shorthand: Session.scalars() executes and applies scalars() in one call.
    product_names = session.scalars(select(Product.name)).all()
    print(f"所有产品名称: {product_names}")


def _demo_multi_column():
    """Example 3: iterate multi-column rows directly, without ``scalars()``."""
    print_section("示例 3: 多列查询 - 不使用 scalars()")

    # Rows unpack like tuples.
    stmt = select(User.id, User.name)
    result = session.execute(stmt)
    print("用户ID和名称:")
    for user_id, user_name in result:
        print(f" ID: {user_id}, 名称: {user_name}")

    # Rows also support positional indexing.
    stmt = select(Product.name, Product.price)
    result = session.execute(stmt)
    print("\n产品名称和价格:")
    for row in result:
        print(f" 产品: {row[0]}, 价格: {row[1]}")


def _demo_scalars_misuse():
    """Example 5: ``scalars()`` on a multi-column query keeps only column 0."""
    print_section("示例 5: 错误使用 scalars() 处理多列查询")
    stmt = select(Order.id, Order.total_amount, User.name).join(User)
    result = session.execute(stmt)
    # scalars() yields only the first column (Order.id); the others are dropped.
    order_ids = result.scalars().all()
    print(f"使用 scalars() 处理多列查询,只获取到第一列值: {order_ids}")


def _demo_aggregates():
    """Example 6: per-user order count and spend, including users without orders."""
    print_section("示例 6: 复杂查询和结果处理")
    # LEFT OUTER JOIN keeps users with no orders. For them COUNT(orders.id)
    # is 0 and SUM(total_amount) is NULL (None in Python).
    stmt = (
        select(
            User.name,
            func.count(Order.id).label("order_count"),
            func.sum(Order.total_amount).label("total_amount"),
        )
        .outerjoin(Order)
        .group_by(User.id)
    )
    result = session.execute(stmt)
    print("用户订单统计:")
    for name, order_count, total_spent in result:
        total_spent = total_spent or 0  # SUM over zero joined rows is NULL
        print(f" {name}: {order_count} 个订单, 总消费 ¥{total_spent:.2f}")


def _demo_entity_rows():
    """Example 7: rows that pair an ORM entity with an aggregate value."""
    print_section("示例 7: 实体查询与关系")
    # Inner join: only users with at least one order are listed.
    stmt = (
        select(User, func.count(Order.id).label("order_count"))
        .join(Order)
        .group_by(User.id)
    )
    result = session.execute(stmt)
    print("用户及其订单数:")
    for user, order_count in result:
        print(f" {user.name} ({user.email}): {order_count} 个订单")


def _demo_mixed():
    """Example 8: a scalar lookup followed by entity and tuple rows."""
    print_section("示例 8: 混合使用 scalar() 和标准结果处理")
    # 2.0-style replacement for the legacy session.query(...).scalar() call.
    user_id = session.scalar(select(User.id).where(User.name == "张三"))

    # All of Zhang San's orders.
    stmt = select(Order).where(Order.user_id == user_id)
    orders = session.execute(stmt).scalars().all()
    print("张三的订单")
    for order in orders:
        print(f" 订单 #{order.id} - 日期: {order.order_date}, 金额: ¥{order.total_amount}")
        # Fetch this order's lines together with each product's name.
        # (Accessing order.items would instead lazy-load the OrderItem objects.)
        stmt = (
            select(OrderItem, Product.name)
            .join(Product)
            .where(OrderItem.order_id == order.id)
        )
        items_result = session.execute(stmt)
        print(" 订单项:")
        for item, product_name in items_result:
            total = item.quantity * item.unit_price
            print(f" - {product_name} x{item.quantity}: ¥{total:.2f}")


def _demo_unique():
    """Example 9: de-duplicate scalar results with ``unique()``."""
    print_section("示例 9: 使用 unique() 和 scalars() 组合")
    # Distinct ages (unique() filters the fetched rows on the Python side).
    stmt = select(User.age)
    unique_ages = session.execute(stmt).scalars().unique().all()
    print(f"不同的年龄值: {unique_ages}")

    # Distinct product ids referenced by any order line.
    stmt = select(OrderItem.product_id)
    product_ids = session.execute(stmt).scalars().unique().all()
    print(f"订单中的产品ID: {product_ids}")


def _demo_many_to_many():
    """Let user 1 favorite products 1 and 2, then show the reverse side."""
    print_section("示例 xx: 多对多关系,用户和商品收藏")
    user = session.get(User, 1)
    p1 = session.get(Product, 1)
    p2 = session.get(Product, 2)
    for product in (p1, p2):
        # user_favorites has a composite primary key, so re-adding an existing
        # favorite would fail at flush time; skip products already present.
        if product not in user.favorite_products:
            user.favorite_products.append(product)
    session.flush()
    # back_populates mirrors the append onto Product.favorited_by.
    print(p1.favorited_by)
    session.commit()


def example_demo():
    """Run every example, in order, against the module-level ``session``."""
    _demo_scalar()
    _demo_scalars()
    _demo_multi_column()
    _demo_scalars_misuse()
    _demo_aggregates()
    _demo_entity_rows()
    _demo_mixed()
    _demo_unique()
    _demo_many_to_many()


if __name__ == "__main__":
    example_demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment