Skip to content

Instantly share code, notes, and snippets.

@aurthurm
Last active June 10, 2025 12:09
Show Gist options
  • Save aurthurm/0a654760174edc3f27f7ef80b905da72 to your computer and use it in GitHub Desktop.
Save aurthurm/0a654760174edc3f27f7ef80b905da72 to your computer and use it in GitHub Desktop.
from typing import Any, Generic, Mapping, Optional, TypeVar

from pydantic import BaseModel, EmailStr
from sqlalchemy import Boolean, Column, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy_mixins import ReprMixin, SerializeMixin, SessionMixin, SmartQueryMixin
from sqlalchemy_mixins.utils import classproperty

from custom_flake_uid import get_flake_uid
def new_query(cls):
    """Build a ``select()`` statement targeting the mapped class *cls*.

    Meant to be installed as ``SessionMixin.query`` through ``classproperty``
    once the mixin's session helpers are removed, so that ``Model.query``
    evaluates to ``select(Model)``.
    """
    statement = select(cls)
    return statement
# Detach SessionMixin from its session-binding API: the repositories open a
# session per operation instead, so ``set_session`` and ``session`` are
# removed and ``Model.query`` is redirected to ``new_query`` (a bare
# ``select(Model)`` statement).
for _session_attr in ("set_session", "session"):
    delattr(SessionMixin, _session_attr)
del _session_attr
SessionMixin.query = classproperty(new_query)
class Base(
    DeclarativeBase, ReprMixin, SerializeMixin, SmartQueryMixin, AsyncAttrs
):
    """Abstract declarative base shared by every ORM entity in this module.

    Combines sqlalchemy_mixins' repr/serialize/smart-query helpers with
    SQLAlchemy's ``AsyncAttrs`` and gives every subclass a string ``uid``
    primary key whose value defaults to ``get_flake_uid()`` on insert.
    """

    # Assign ReprMixin's __repr__ explicitly so it is the one used no matter
    # how the other bases order __repr__ in the MRO.
    __repr__ = ReprMixin.__repr__
    __name__: str
    # eager_defaults=True (SQLAlchemy mapper option): fetch server-generated
    # defaults as part of the flush rather than on later attribute access.
    __mapper_args__ = {"eager_defaults": True}
    # Declarative does not map a table for this class itself.
    __abstract__ = True

    uid = Column(
        String,
        primary_key=True,
        index=True,
        nullable=False,
        default=get_flake_uid,
    )

    def marshall(self, exclude=None) -> Mapping[str, Any]:
        """Convert the instance to a plain dict built from ``self.__dict__``.

        Args:
            exclude: Optional iterable of attribute names to leave out. A
                single name may be passed as a ``str``. The argument is never
                modified. ``_sa_instance_state`` is always left out.

        Returns:
            A new dict of attribute name -> value, in ``__dict__`` order.
            NOTE(review): ``__dict__`` presumably holds only attributes that
            are currently loaded on the instance; confirm before relying on
            unloaded columns or relationships being present.
        """
        if isinstance(exclude, str):
            exclude = (exclude,)
        # Build a private set rather than appending to the caller's list:
        # appending mutated the argument on every call and failed for tuples.
        skipped = {"_sa_instance_state", *(exclude or ())}
        return {
            name: value
            for name, value in self.__dict__.items()
            if name not in skipped
        }
class User(Base):
    """ORM entity for an application user account (table ``user``)."""

    __tablename__ = "user"

    first_name = Column(String, index=True)
    last_name = Column(String, index=True)
    email = Column(String, unique=True, index=True, nullable=False)
    mobile_phone = Column(String, nullable=True)
    business_phone = Column(String, nullable=True)
    user_name = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    # Count of consecutive failed logins. UserService.has_access does
    # arithmetic on it, so it must not start as NULL. The schemas are dumped
    # with exclude_unset=True, which means the schema-level default of 0
    # never reaches the INSERT. The column default has to supply it.
    login_retry = Column(Integer, default=0)
    is_blocked = Column(Boolean(), default=False)
    avatar = Column(String, nullable=True)
    bio = Column(String, nullable=True)
    # Stored as text to match ``UserBase.default_route: str | None``. A
    # Boolean column cannot hold a route string.
    default_route = Column(String, nullable=True)
    is_active = Column(Boolean(), default=True)
    is_superuser = Column(Boolean(), default=False)

    @property
    def has_password(self):
        """True when a non-empty password hash is stored."""
        return bool(self.hashed_password)

    @property
    def full_name(self):
        """First and last name joined by a space, skipping missing parts.

        Avoids rendering a literal "None" when one part is unset.
        """
        return " ".join(part for part in (self.first_name, self.last_name) if part)
class UserBase(BaseModel):
    # Fields shared by the user create/update payloads. Every field has a
    # default, so partial payloads validate. BaseService._import dumps with
    # exclude_unset=True, so these defaults are not sent to the repository
    # unless the caller set the field explicitly.
    email: EmailStr | None = None
    first_name: str | None = None
    last_name: str | None = None
    password: str | None = None
    user_name: str | None = None
    avatar: str | None = None
    bio: str | None = None
    default_route: str | None = None
    # ``groups`` is a collection (it defaults to an empty list), so it is
    # typed as a list of Group rather than a single optional Group.
    groups: list[Group] | None = []
    login_retry: int | None = 0
    is_blocked: bool | None = False
    is_active: bool | None = True
    is_superuser: bool = False
class UserCreate(UserBase):
    # Registration payload with the same fields as UserBase. UserService.create
    # reads ``user_name`` and ``password`` from it and stores only a hash.
    ...
class UserUpdate(UserBase):
    # Partial-update payload with the same fields as UserBase.
    # BaseService._import dumps it with exclude_unset=True, so only fields the
    # caller explicitly set are written.
    ...
M = TypeVar("M", bound=Base)


class BaseRepository(Generic[M]):
    """Generic async CRUD repository for a single ORM model class.

    Every write opens its own session from ``async_session``, flushes and
    commits, and rolls back before re-raising on any error.

    NOTE(review): ``create``, ``bulk_create`` and ``update`` call
    ``self.model.fill(...)``. ``Base`` as shown does not define ``fill``, and
    its listed mixins do not appear to provide it (sqlalchemy_mixins exposes
    it on ``ActiveRecordMixin``). Confirm where it comes from.
    """

    # Module-level session factory defined outside this excerpt (presumably
    # an ``async_sessionmaker``). It is called once per write.
    async_session = async_session
    model: type[M] | None = None

    def __init__(self, model: type[M]) -> None:
        self.model = model

    async def _persist(self, instances) -> None:
        """Add *instances* in one fresh session and commit them together.

        Rolls back and re-raises on any error.
        NOTE(review): depending on the factory's ``expire_on_commit`` setting,
        attributes of the instances may be expired after commit. Confirm
        before reading them outside a session.
        """
        async with self.async_session() as session:
            try:
                session.add_all(instances)
                await session.flush()
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def save(self, m: M) -> M:
        """Persist a single instance and return it.

        Raises:
            ValueError: if *m* is None or otherwise falsy.
        """
        if not m:
            raise ValueError("No model provided to save")  # noqa
        await self._persist([m])
        return m

    async def save_all(self, items):
        """Persist all *items* in one commit and return them.

        Raises:
            ValueError: if *items* is empty.
        """
        if not items:
            raise ValueError("No items provided to save")
        await self._persist(items)
        return items

    async def create(self, **kwargs) -> M:
        """Build a new model from *kwargs* via ``fill`` and save it.

        Raises:
            ValueError: if no keyword data is given.
        """
        if not kwargs:
            raise ValueError("No data provided to create a new model")
        filled = self.model.fill(self.model(), **kwargs)
        return await self.save(filled)

    async def bulk_create(self, bulk: list[dict]) -> list[M]:
        """Build one model per dict in *bulk* and save them in one commit.

        Raises:
            ValueError: if *bulk* is empty.
        """
        if not bulk:
            raise ValueError("No data provided to create a new models")
        to_save = [self.model.fill(self.model(), **data) for data in bulk]
        return await self.save_all(to_save)

    async def update(self, uid: str, **data) -> M:
        """Load the row with *uid*, apply *data* via ``fill`` and save it.

        Raises:
            ValueError: if *uid* or *data* is missing, or no row has *uid*.
        """
        if not uid or not data:
            raise ValueError("Both uid and data are required to update model")
        item = await self.get(uid=uid)
        if item is None:
            # Fail clearly here instead of inside fill(None, ...).
            raise ValueError(f"No {self.model.__name__} found with uid {uid!r}")
        filled = self.model.fill(item, **data)
        return await self.save(filled)

    # Further query helpers (get, all, get_all, get_by_uids, get_related,
    # search, paginate, delete) are elided from this excerpt. ``update``
    # above and the service layer call them.
    ...
class UserRepository(BaseRepository[User]):
    """Repository preconfigured for the ``User`` entity."""

    def __init__(self) -> None:
        # Takes no arguments: the model is always User.
        super().__init__(model=User)
E = TypeVar("E", bound=Base)
C = TypeVar("C", bound=BaseModel)
U = TypeVar("U", bound=BaseModel)


class BaseService(Generic[E, C, U]):
    """Generic service layer that delegates persistence to a repository.

    Type parameters: ``E`` is the ORM entity (bound to this module's
    ``Base``), ``C`` is the create schema and ``U`` is the update schema.
    Schemas or plain dicts are converted with ``_import`` before they reach
    the repository. Methods taking ``related`` return the entity re-fetched
    through ``repository.get_related(related=...)`` (presumably to load
    those relationships) instead of the raw repository result.
    """

    def __init__(self, repository) -> None:
        # *repository* is a repository class (or zero-argument factory). The
        # service owns its own instance.
        self.repository: BaseRepository = repository()

    async def paging_filter(
        self,
        page_size: int | None = None,
        after_cursor: str | None = None,
        before_cursor: str | None = None,
        filters: list[dict] | dict | None = None,
        sort_by: list[str] | None = None,
        **kwargs,
    ):
        """Forward cursor-pagination arguments to ``repository.paginate``."""
        return await self.repository.paginate(
            page_size, after_cursor, before_cursor, filters, sort_by, **kwargs
        )

    async def search(self, **kwargs) -> list[E]:
        """Delegate to ``repository.search``."""
        return await self.repository.search(**kwargs)

    async def all(self) -> list[E]:
        """Delegate to ``repository.all``."""
        return await self.repository.all()

    async def get(self, **kwargs) -> E:
        """Delegate to ``repository.get`` with keyword filters."""
        return await self.repository.get(**kwargs)

    async def get_by_uids(self, uids: list[str]) -> list[E]:
        """Delegate to ``repository.get_by_uids``."""
        return await self.repository.get_by_uids(uids)

    async def get_all(self, **kwargs) -> list[E]:
        """Delegate to ``repository.get_all`` with keyword filters."""
        return await self.repository.get_all(**kwargs)

    async def get_related(self, related: list[str], **kwargs) -> E:
        """Delegate to ``repository.get_related`` with the *related* names."""
        return await self.repository.get_related(related=related, **kwargs)

    async def create(self, c: C | dict, related: list[str] | None = None) -> E:
        """Create an entity from a schema or dict, optionally re-fetched with *related*."""
        data = self._import(c)
        created = await self.repository.create(**data)
        if not related:
            return created
        return await self.get_related(related=related, uid=created.uid)

    async def bulk_create(
        self, bulk: list[dict | C], related: list[str] | None = None
    ) -> list[E]:
        """Create many entities in one commit, optionally re-fetched with *related*."""
        created = await self.repository.bulk_create([self._import(b) for b in bulk])
        if not related:
            return created
        # Re-fetched one at a time, preserving input order.
        return [await self.get_related(related=related, uid=x.uid) for x in created]

    async def update(
        self, uid: str, update: U | dict, related: list[str] | None = None
    ) -> E:
        """Apply *update* to the entity with *uid*, optionally re-fetched with *related*."""
        data = self._import(update)
        # Strip ``uid`` from the dumped payload. For schemas the key only
        # appears after dumping, and for dicts _import returned a copy, so
        # the caller's object is untouched. Leaving it in would collide with
        # the positional ``uid`` argument below.
        data.pop("uid", None)
        updated = await self.repository.update(uid, **data)
        if not related:
            return updated
        return await self.get_related(related=related, uid=updated.uid)

    async def save(self, entity: E, related: list[str] | None = None) -> E:
        """Persist *entity* as-is, optionally re-fetched with *related*."""
        saved = await self.repository.save(entity)
        if not related:
            return saved
        return await self.get_related(related=related, uid=saved.uid)

    async def delete(self, uid: str) -> None:
        """Delegate to ``repository.delete``."""
        return await self.repository.delete(uid)

    # Further service helpers are elided from this excerpt.
    ...

    @classmethod
    def _import(cls, schema_in: C | U | dict) -> dict:
        """Return a fresh dict of the fields to persist from *schema_in*.

        Dicts are shallow-copied so that later ``pop``/``del`` calls never
        mutate the caller's mapping. Pydantic models are dumped with
        ``exclude_unset=True``, so only explicitly provided fields are kept.
        """
        if isinstance(schema_in, dict):
            return dict(schema_in)
        return schema_in.model_dump(exclude_unset=True)
class UserService(BaseService[User, UserCreate, UserUpdate]):
    """User-specific service: registration, updates, login checks, flags."""

    # Failed password attempts allowed before has_access blocks the account.
    MAX_LOGIN_ATTEMPTS = 3

    def __init__(self) -> None:
        super().__init__(UserRepository)

    async def create(
        self, user_in: UserCreate, related: list[str] | None = None
    ) -> User:
        """Register a user after username-uniqueness and password-policy checks.

        The plain ``password`` is replaced by ``hashed_password`` before the
        data reaches the repository.

        Raises:
            AlreadyExistsError: if ``user_in.user_name`` is already taken.
            ValidationError: if ``password_check`` rejects the password.
        """
        if await self.get_by_username(user_in.user_name):
            raise AlreadyExistsError("Username already exist")
        policy = password_check(user_in.password, user_in.user_name)
        if not policy["password_ok"]:
            raise ValidationError(policy["message"])
        data = dict(self._import(user_in))
        # Use pop rather than del: ``password`` is missing from the dump when
        # it was never set (exclude_unset=True), and del would raise KeyError.
        data.pop("password", None)
        data["hashed_password"] = get_password_hash(user_in.password)
        return await super().create(data, related=related)

    async def update(
        self,
        user_uid: str,
        user_in: UserUpdate,
        related: list[str] | None = None,
    ) -> User:
        """Apply a partial update, re-hashing the password when one is supplied.

        Raises:
            ValidationError: if a supplied password fails ``password_check``.
        """
        # Copy, so a dict argument is never mutated by the pops below.
        update_data = dict(self._import(user_in))
        password = update_data.pop("password", None)
        if password is not None:
            policy = password_check(password, update_data.get("user_name"))
            if not policy["password_ok"]:
                raise ValidationError(policy["message"])
            update_data["hashed_password"] = get_password_hash(password)
        update_data.pop("user", None)
        return await super().update(user_uid, update_data, related=related)

    async def has_access(self, user: User, password: str) -> User:
        """Check that *user* may log in with *password* and return the user.

        Each wrong password increments ``login_retry``. Reaching
        ``MAX_LOGIN_ATTEMPTS`` blocks the account. A successful check resets
        the counter.

        Raises:
            Exception: if the account is blocked or inactive, or if the
                password is wrong. In the last case the message reports the
                remaining attempts or the new block.
        """
        if user.is_blocked:
            raise Exception("Blocked Account: Reset Password to regain access")
        if not user.is_active:
            raise Exception("In active account: contact administrator")
        if not verify_password(password, user.hashed_password):
            # A NULL counter counts as zero.
            attempts = (user.login_retry or 0) + 1
            user.login_retry = attempts
            # Use >= rather than ==, so a counter already past the limit
            # still blocks instead of allowing unlimited retries.
            if attempts >= self.MAX_LOGIN_ATTEMPTS:
                user.is_blocked = True
                msg = "Sorry your Account has been Blocked"
            else:
                msg = f"Wrong Password {self.MAX_LOGIN_ATTEMPTS - attempts} attempts left"
            await self.save(user)
            raise Exception(msg)
        if user.login_retry != 0:
            user.login_retry = 0
            await self.save(user)
        return user

    async def authenticate(self, username, password) -> User:
        """Authenticate by username (e-mail logins are rejected).

        Raises:
            Exception: for e-mail usernames, unknown users, or any failure
                reported by ``has_access``.
        """
        if is_valid_email(username):
            raise Exception("Use your username authenticate")
        user = await self.get_by_username(username)
        if user is None:
            raise Exception("Incorrect username or password")
        # has_access is a coroutine and must be awaited. Without the await,
        # the caller receives an un-run coroutine and no check is performed.
        return await self.has_access(user, password)

    async def get_by_email(self, email) -> User | None:
        """Return the user with *email*, or None if there is none."""
        return await self.get(email=email)

    async def get_by_username(self, username) -> User | None:
        """Return the user with *username*, or None if there is none."""
        return await self.get(user_name=username)

    async def _set_flags(self, user_uid: str, **flags: bool) -> User:
        """Write only the given boolean account flags for *user_uid*.

        UserUpdate is dumped with exclude_unset=True, so no other column is
        touched.
        """
        return await super().update(user_uid, UserUpdate(**flags))

    async def give_super_powers(self, user_uid: str) -> User:
        """Mark the user as superuser."""
        return await self._set_flags(user_uid, is_superuser=True)

    async def strip_super_powers(self, user_uid: str) -> User:
        """Remove the user's superuser status."""
        return await self._set_flags(user_uid, is_superuser=False)

    async def activate(self, user_uid: str) -> User:
        """Mark the user account as active."""
        return await self._set_flags(user_uid, is_active=True)

    async def deactivate(self, user_uid: str) -> User:
        """Mark the user account as inactive."""
        return await self._set_flags(user_uid, is_active=False)
@aurthurm
Copy link
Author

aurthurm commented Sep 1, 2024

For a detailed use case, see https://github.com/beak-insights/felicity-lims

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment