Last active: June 10, 2025 12:09
Save aurthurm/0a654760174edc3f27f7ef80b905da72 to your computer and use it in GitHub Desktop.
Repository pattern built on sqlalchemy-mixins (https://github.com/absent1706/sqlalchemy-mixins)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Generic, Mapping, Optional, TypeVar

from pydantic import BaseModel, EmailStr
from sqlalchemy import Boolean, Column, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy_mixins import ReprMixin, SerializeMixin, SessionMixin, SmartQueryMixin
from sqlalchemy_mixins.utils import classproperty

from custom_flake_uid import get_flake_uid
def new_query(cls):
    """Return a fresh ``select()`` statement targeting *cls*.

    Intended to be installed as ``SessionMixin.query`` so that ``Model.query``
    yields a SQLAlchemy ``Select`` construct.
    """
    statement = select(cls)
    return statement
# Strip SessionMixin's session-related members, then make ``Model.query``
# resolve (at class level) to ``select(Model)`` via ``new_query``.
for _session_attr in ('set_session', 'session'):
    delattr(SessionMixin, _session_attr)
del _session_attr
SessionMixin.query = classproperty(new_query)
class Base(
    DeclarativeBase, ReprMixin, SerializeMixin, SmartQueryMixin, AsyncAttrs
):
    """Abstract declarative base shared by all models.

    Combines SQLAlchemy's ``DeclarativeBase`` with the sqlalchemy-mixins
    helpers (repr, serialization, smart queries) and ``AsyncAttrs``. Every
    subclass gets a string ``uid`` primary key whose default comes from
    ``get_flake_uid()``.
    """

    # Pin ReprMixin's __repr__ explicitly instead of relying on MRO order.
    __repr__ = ReprMixin.__repr__
    __name__: str
    # Ask the ORM to fetch server-generated defaults right after INSERT/UPDATE.
    __mapper_args__ = {"eager_defaults": True}
    __abstract__ = True

    uid = Column(
        String,
        primary_key=True,
        index=True,
        nullable=False,
        default=get_flake_uid,
    )

    def marshall(self, exclude=None) -> Mapping[str, Any]:
        """Return the instance's loaded attributes as a new plain dict.

        Reads ``self.__dict__`` directly, so attributes that are not currently
        loaded on the instance are absent. SQLAlchemy's internal
        ``_sa_instance_state`` entry is always omitted.

        :param exclude: optional iterable of attribute names to leave out.
            It is never modified.
        :returns: a new dict mapping attribute name to value.
        """
        # Copy into a set: the caller's list is left untouched (the previous
        # implementation appended to it) and lookups are O(1).
        skip = set(exclude or ())
        skip.add("_sa_instance_state")
        return {
            field: value
            for field, value in self.__dict__.items()
            if field not in skip
        }
class User(Base):
    """Application user account."""

    __tablename__ = "user"

    first_name = Column(String, index=True)
    last_name = Column(String, index=True)
    email = Column(String, unique=True, index=True, nullable=False)
    mobile_phone = Column(String, nullable=True)
    business_phone = Column(String, nullable=True)
    user_name = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    # Consecutive failed logins. Defaults to 0 so rows inserted without an
    # explicit value never hold NULL, which numeric retry checks cannot
    # compare against.
    login_retry = Column(Integer, default=0)
    is_blocked = Column(Boolean(), default=False)
    avatar = Column(String, nullable=True)
    bio = Column(String, nullable=True)
    # NOTE(review): declared Boolean here, but UserBase.default_route is
    # ``str | None``; SQLAlchemy's Boolean type rejects non-bool values on
    # bind. Confirm which type is intended.
    default_route = Column(Boolean(), nullable=True)
    is_active = Column(Boolean(), default=True)
    is_superuser = Column(Boolean(), default=False)

    @property
    def has_password(self) -> bool:
        """Whether a non-empty password hash is stored."""
        return bool(self.hashed_password)

    @property
    def full_name(self) -> str:
        """First and last name joined by a space, skipping missing parts."""
        return " ".join(
            part for part in (self.first_name, self.last_name) if part
        )
class UserBase(BaseModel):
    """Fields shared by the user create/update payloads.

    Every field has a default, so the same schema can describe partial
    updates; only fields the caller explicitly sets count as "set" for
    ``model_dump(exclude_unset=True)``.
    """

    email: EmailStr | None = None
    first_name: str | None = None
    last_name: str | None = None
    # Plaintext password as supplied by the client (not the stored hash).
    password: str | None = None
    user_name: str | None = None
    avatar: str | None = None
    bio: str | None = None
    default_route: str | None = None
    # A list of groups: the annotation now matches the ``[]`` default
    # (previously ``Optional[Group]``). Pydantic copies mutable defaults
    # per instance.
    groups: list[Group] | None = []
    login_retry: int | None = 0
    is_blocked: bool | None = False
    is_active: bool | None = True
    is_superuser: bool = False
class UserCreate(UserBase):
    """Payload for creating a user; adds nothing to ``UserBase``."""
class UserUpdate(UserBase):
    """Payload for updating a user; adds nothing to ``UserBase``."""
# Mapped model type managed by a repository.
M = TypeVar("M", bound=Base)


class BaseRepository(Generic[M]):
    """Generic async persistence for a single mapped model class.

    Every write opens its own session from ``async_session``, commits on
    success, and rolls back and re-raises on failure.
    """

    # Session factory; ``async_session`` comes from elsewhere in the project.
    async_session = async_session
    # The mapped model *class* this repository manages (set in __init__).
    model: type[M] | None = None

    def __init__(self, model: type[M]) -> None:
        """
        :param model: the mapped model class (not an instance).
        """
        self.model = model

    async def _persist(self, instances) -> None:
        """Add *instances* to a fresh session and commit them in one transaction."""
        # NOTE(review): if the session factory keeps SQLAlchemy's default
        # expire_on_commit=True, returned instances are expired once the
        # session closes; confirm the factory disables it.
        async with self.async_session() as session:
            try:
                session.add_all(instances)
                # commit() flushes pending changes itself.
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def save(self, m: M) -> M:
        """Persist a single model instance and return it.

        :raises ValueError: if *m* is falsy (e.g. ``None``).
        """
        if not m:
            raise ValueError("No model provided to save")  # noqa
        await self._persist([m])
        return m

    async def save_all(self, items):
        """Persist several instances in one transaction and return *items*.

        :raises ValueError: if *items* is empty.
        """
        if not items:
            raise ValueError("No items provided to save")
        await self._persist(items)
        return items

    async def create(self, **kwargs) -> M:
        """Instantiate ``self.model`` filled with *kwargs* and persist it.

        :raises ValueError: if no data is given.
        """
        if not kwargs:
            raise ValueError("No data provided to create a new model")
        filled = self.model.fill(self.model(), **kwargs)
        return await self.save(filled)

    async def bulk_create(self, bulk: list[dict]) -> list[M]:
        """Create one instance per dict in *bulk* and persist them together.

        :raises ValueError: if *bulk* is empty.
        """
        if not bulk:
            raise ValueError("No data provided to create a new models")
        to_save = [self.model.fill(self.model(), **data) for data in bulk]
        return await self.save_all(to_save)

    async def update(self, uid: str, **data) -> M:
        """Load the row with *uid*, apply *data* to it, and persist it.

        :raises ValueError: if *uid* or *data* is missing, or no row matches.
        """
        if not uid or not data:
            raise ValueError("Both uid and data are required to update model")
        # ``get`` lives in the elided part of this class; presumably it
        # returns None when nothing matches. TODO confirm.
        item = await self.get(uid=uid)
        if item is None:
            # Fail with a clear error instead of an AttributeError in fill().
            raise ValueError(f"No {self.model.__name__} found with uid {uid!r}")
        filled = self.model.fill(item, **data)
        return await self.save(filled)

    # Remaining methods (get, get_all, search, paginate, delete, ...) are
    # elided in this snippet.
    ...
class UserRepository(BaseRepository[User]):
    """Repository bound to the ``User`` model."""

    def __init__(self) -> None:
        super().__init__(model=User)
# Type parameters: mapped entity, create schema, update schema.
E = TypeVar("E", bound=Base)
C = TypeVar("C", bound=BaseModel)
U = TypeVar("U", bound=BaseModel)


class BaseService(Generic[E, C, U]):
    """Generic service layer that delegates persistence to a repository.

    Write methods accept either Pydantic schemas or plain dicts. When
    ``related`` is given, results are reloaded through ``get_related``.
    """

    def __init__(self, repository) -> None:
        """
        :param repository: a repository *class*; it is instantiated here with
            no arguments.
        """
        self.repository: BaseRepository = repository()

    async def paging_filter(
        self,
        page_size: int | None = None,
        after_cursor: str | None = None,
        before_cursor: str | None = None,
        filters: list[dict] | dict | None = None,
        sort_by: list[str] | None = None,
        **kwargs,
    ):
        return await self.repository.paginate(
            page_size, after_cursor, before_cursor, filters, sort_by, **kwargs
        )

    async def search(self, **kwargs) -> list[E]:
        return await self.repository.search(**kwargs)

    async def all(self) -> list[E]:
        return await self.repository.all()

    async def get(self, **kwargs) -> E:
        return await self.repository.get(**kwargs)

    async def get_by_uids(self, uids: list[str]) -> list[E]:
        return await self.repository.get_by_uids(uids)

    async def get_all(self, **kwargs) -> list[E]:
        return await self.repository.get_all(**kwargs)

    async def get_related(self, related: list[str], **kwargs) -> E:
        return await self.repository.get_related(related=related, **kwargs)

    async def create(self, c: C | dict, related: list[str] | None = None) -> E:
        """Create an entity from *c*; reload with *related* if given."""
        created = await self.repository.create(**self._import(c))
        if not related:
            return created
        return await self.get_related(related=related, uid=created.uid)

    async def bulk_create(
        self, bulk: list[dict | C], related: list[str] | None = None
    ) -> list[E]:
        """Create one entity per item in *bulk*; reload each with *related*."""
        created = await self.repository.bulk_create([self._import(b) for b in bulk])
        if not related:
            return created
        return [await self.get_related(related=related, uid=x.uid) for x in created]

    async def update(
        self, uid: str, update: U | dict, related: list[str] | None = None
    ) -> E:
        """Apply *update* to the entity with *uid*; reload with *related*."""
        # Build a new dict without "uid". This never mutates a caller's dict
        # and also strips uid from Pydantic inputs; ``"uid" in model`` never
        # matches because BaseModel iterates (name, value) pairs.
        data = {k: v for k, v in self._import(update).items() if k != "uid"}
        updated = await self.repository.update(uid, **data)
        if not related:
            return updated
        return await self.get_related(related=related, uid=updated.uid)

    async def save(self, entity: E, related: list[str] | None = None) -> E:
        """Persist *entity*; reload with *related* if given."""
        saved = await self.repository.save(entity)
        if not related:
            return saved
        return await self.get_related(related=related, uid=saved.uid)

    async def delete(self, uid: str) -> None:
        return await self.repository.delete(uid)

    # Further service methods are elided in this snippet.
    ...

    @classmethod
    def _import(cls, schema_in: C | U | dict) -> dict:
        """Convert a Pydantic schema (or dict) into a fresh dict.

        Pydantic models are dumped with ``exclude_unset=True``, so only fields
        the caller explicitly set are included. Dicts are shallow-copied so
        that downstream edits never mutate the caller's dict.
        """
        if isinstance(schema_in, dict):
            return dict(schema_in)
        return schema_in.model_dump(exclude_unset=True)
class UserService(BaseService[User, UserCreate, UserUpdate]):
    """User-specific rules: password policy and hashing, login checks with a
    failed-attempt counter, and account flag toggles."""

    # Consecutive wrong passwords after which an account is blocked.
    max_login_retries = 3

    def __init__(self) -> None:
        super().__init__(UserRepository)

    async def create(
        self, user_in: UserCreate, related: list[str] | None = None
    ) -> User:
        """Create a user after checking username uniqueness and password policy.

        The plaintext ``password`` is replaced by ``hashed_password``.

        :raises AlreadyExistsError: if the username is already taken.
        :raises ValidationError: if ``password_check`` rejects the password.
        """
        if await self.get_by_username(user_in.user_name):
            raise AlreadyExistsError("Username already exist")
        policy = password_check(user_in.password, user_in.user_name)
        if not policy["password_ok"]:
            raise ValidationError(policy["message"])
        data = self._import(user_in)
        # Never persist the plaintext; pop() also tolerates it being unset.
        data.pop("password", None)
        data["hashed_password"] = get_password_hash(user_in.password)
        return await super().create(data, related=related)

    async def update(
        self,
        user_uid: str,
        user_in: UserUpdate,
        related: list[str] | None = None,
    ) -> User:
        """Update a user, hashing a new password if one was supplied.

        :raises Exception: if the new password fails ``password_check``.
        """
        update_data = self._import(user_in)
        # Always drop the plaintext field. Only hash an actual value: an
        # explicit None must not be hashed or null out hashed_password.
        password = update_data.pop("password", None)
        if password is not None:
            # NOTE(review): user_name may be None on partial updates.
            # create() raises ValidationError here; this path keeps Exception.
            policy = password_check(password, user_in.user_name)
            if not policy["password_ok"]:
                raise Exception(policy["message"])
            update_data["hashed_password"] = get_password_hash(password)
        update_data.pop("user", None)
        return await super().update(user_uid, update_data, related=related)

    async def has_access(self, user: User, password: str):
        """Check *password* for *user*, enforcing account state and lockout.

        A wrong password increments ``login_retry``. Reaching
        ``max_login_retries`` blocks the account. A correct password resets
        the counter to 0.

        :returns: *user* on success.
        :raises Exception: if the account is blocked or inactive, or the
            password is wrong.
        """
        if user.is_blocked:
            raise Exception("Blocked Account: Reset Password to regain access")
        if not user.is_active:
            raise Exception("In active account: contact administrator")
        # login_retry is a nullable column; treat NULL as no failures yet.
        retries = user.login_retry or 0
        if not verify_password(password, user.hashed_password):
            user.login_retry = retries + 1
            # ">=" also re-blocks accounts unblocked without a counter reset;
            # previously they got an empty message and were never blocked.
            if user.login_retry >= self.max_login_retries:
                user.is_blocked = True
                msg = "Sorry your Account has been Blocked"
            else:
                msg = (
                    f"Wrong Password "
                    f"{self.max_login_retries - 1 - retries} attempts left"
                )
            await self.save(user)
            raise Exception(msg)
        if user.login_retry != 0:
            user.login_retry = 0
            await self.save(user)
        return user

    async def authenticate(self, username, password):
        """Authenticate by username (emails are rejected) and return the user.

        :raises Exception: on an email login, an unknown username, or any
            failure raised by ``has_access``.
        """
        if is_valid_email(username):
            raise Exception("Use your username authenticate")
        user = await self.get_by_username(username)
        if user is None:
            # Fail like a login error instead of an AttributeError on None.
            raise Exception("Incorrect username or password")
        # has_access is a coroutine: without await the caller would receive
        # an un-run coroutine object instead of the user.
        return await self.has_access(user, password)

    async def get_by_email(self, email) -> User | None:
        """Return the user with *email*, or None if there is none."""
        return await self.get(email=email) or None

    async def get_by_username(self, username) -> User | None:
        """Return the user with *username* (None presumably means not found)."""
        return await self.get(user_name=username)

    # The flag toggles below send a schema with only the target field set.
    # Because _import dumps with exclude_unset=True, only that column is
    # written. The previous versions called self.get() without await and
    # positionally, and round-tripped every column through UserUpdate.

    async def give_super_powers(self, user_uid: str) -> User:
        """Grant superuser rights to the user with *user_uid*."""
        return await self.update(user_uid, UserUpdate(is_superuser=True))

    async def strip_super_powers(self, user_uid: str) -> User:
        """Revoke superuser rights from the user with *user_uid*."""
        return await self.update(user_uid, UserUpdate(is_superuser=False))

    async def activate(self, user_uid: str) -> User:
        """Mark the user with *user_uid* as active."""
        return await super().update(user_uid, UserUpdate(is_active=True))

    async def deactivate(self, user_uid: str) -> User:
        """Mark the user with *user_uid* as inactive."""
        return await super().update(user_uid, UserUpdate(is_active=False))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
For a detailed use case, see https://github.com/beak-insights/felicity-lims