"""Permission provider, query helpers and response serializers for the system models."""

from __future__ import annotations

import secrets
import uuid
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import Any

from fastapi import Depends, Request
from sqlalchemy import Select, func, or_, select
from sqlalchemy.orm import Session, selectinload

from iti.auth import (
    Principal,
    StaticPermissionProvider,
    create_access_token,
    create_refresh_token,
    require_user,
)
from iti.db import get_db
from iti.exceptions import BizError, Unauthorized
from iti.responses import pagination
from iti_system.enums import StatusEnum
from iti_system.models import (
    Role,
    SysConfig,
    SysDept,
    SysDictData,
    SysDictType,
    SysFile,
    SysLog,
    SysMenu,
    SysUserAttribute,
    User,
)


class SystemPermissionProvider(StaticPermissionProvider):
    """Permission provider that resolves principals from the ``User`` table."""

    def load_principal(self, principal_id: str, request: Request) -> Principal | None:
        """Load the enabled user with ``principal_id`` and wrap it in a ``Principal``.

        Returns ``None`` when the application has no ``db_sessionmaker`` on its
        state, or when no user matches the id with an enabled status.
        """
        factory = getattr(request.app.state, "db_sessionmaker", None)
        if factory is None:
            return None
        with factory() as db:
            user = db.scalar(
                select(User)
                .where(User.id == principal_id, User.status == StatusEnum.ENABLED.value)
                .options(selectinload(User.roles).selectinload(Role.menus))
            )
            if user is None:
                return None
            # Built while the session is still open so reading the user's
            # attributes never touches a detached instance.
            return Principal(
                id=user.id,
                permissions=frozenset(user.permissions),
                roles=frozenset(user.role_codes),
            )

    def has_permission(self, principal: Principal, code: str) -> bool:
        """Return True if the principal has the ``ADMIN`` role or holds ``code``."""
        return "ADMIN" in principal.roles or code in principal.permissions


def current_user(
    principal: Principal = Depends(require_user),
    db: Session = Depends(get_db),
) -> User:
    """FastAPI dependency returning the ``User`` row for the authenticated principal.

    Roles (with their menus), departments and attributes are eagerly loaded.

    Raises:
        Unauthorized: if no user row matches ``principal.id``.
    """
    user = db.scalar(
        select(User)
        .where(User.id == principal.id)
        .options(
            selectinload(User.roles).selectinload(Role.menus),
            selectinload(User.depts),
            selectinload(User.attributes),
        )
    )
    if user is None:
        raise Unauthorized("用户不存在或已失效")
    return user


def paginate_query(db: Session, stmt: Select, *, page: int = 1, size: int = 10) -> dict[str, Any]:
    """Run ``stmt`` for one page and return the rows plus pagination metadata.

    ``page`` and ``size`` are clamped to a minimum of 1. The result has the
    keys ``"items"`` (the page's rows) and ``"page"`` (the value returned by
    ``pagination(page=..., size=..., total=...)``).
    """
    page = max(page, 1)
    size = max(size, 1)
    # Ordering is irrelevant for counting, so it is stripped from the subquery.
    total = db.scalar(select(func.count()).select_from(stmt.order_by(None).subquery())) or 0
    items = db.scalars(stmt.limit(size).offset((page - 1) * size)).all()
    return {"items": items, "page": pagination(page=page, size=size, total=total)}


def apply_keyword(stmt: Select, keyword: str | None, *columns) -> Select:
    """Filter ``stmt`` to rows where any of ``columns`` contains ``keyword``.

    The keyword is matched literally: ``%`` and ``_`` typed by the user are
    escaped so they do not act as LIKE wildcards. The statement is returned
    unchanged when ``keyword`` is empty or no columns are given.
    """
    # or_() with no clauses is deprecated in SQLAlchemy, so skip filtering
    # entirely when there is nothing to match against.
    if not keyword or not columns:
        return stmt
    # Escape the escape character first, then the two LIKE wildcards.
    escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    return stmt.where(or_(*(column.like(pattern, escape="/") for column in columns)))


def get_or_404(db: Session, model, item_id: str, message: str = "数据不存在"):
    """Fetch ``model`` by primary key, raising ``BizError`` with code 404 if missing."""
    item = db.get(model, item_id)
    if item is None:
        raise BizError(message, code=404)
    return item


def dump_user(user: User) -> dict[str, Any]:
    """Serialize a ``User`` into a camelCase response dict."""
    return {
        "id": user.id,
        "username": user.username,
        "phone": user.phone,
        "email": user.email,
        "realname": user.realname,
        "desc": user.desc,
        "avatar": user.avatar,
        "gender": user.gender,
        "status": user.status,
        "roles": [role.id for role in user.roles],
        "depts": user.dept_ids,
        "permissions": user.permissions,
        "attributes": user.attribute_map,
        "createdAt": user.created_at,
        "updatedAt": user.updated_at,
    }


def dump_role(role: Role) -> dict[str, Any]:
    """Serialize a ``Role``; ``"permissions"`` lists the ids of its menus."""
    return {
        "id": role.id,
        "name": role.name,
        "code": role.code,
        "desc": role.desc,
        "sort": role.sort,
        "status": role.status,
        "permissions": [menu.id for menu in role.menus],
        "createdAt": role.created_at,
        "updatedAt": role.updated_at,
    }


def dump_menu(menu: SysMenu, *, with_children: bool = False) -> dict[str, Any]:
    """Serialize a ``SysMenu``.

    With ``with_children=True`` the result also gets a ``"children"`` list,
    built recursively from ``menu.children`` ordered by ``(sort, name)``.
    """
    data = {
        "id": menu.id,
        "name": menu.name,
        "type": menu.type,
        "path": menu.path,
        "component": menu.component,
        "redirect": menu.redirect,
        "sort": menu.sort,
        "authCode": menu.auth_code,
        "meta": menu.meta or {},
        "status": menu.status,
        "parentId": menu.parent_id,
        "createdAt": menu.created_at,
        "updatedAt": menu.updated_at,
    }
    if with_children:
        data["children"] = [
            dump_menu(child, with_children=True)
            for child in sorted(menu.children, key=lambda item: (item.sort, item.name))
        ]
    return data


def build_tree(items: Iterable[Any], dumper) -> list[dict[str, Any]]:
    """Arrange flat ``items`` into a nested list of dicts using ``parent_id`` links.

    Each item is converted with ``dumper`` and given a ``"children"`` list.
    Siblings are ordered by ``(sort, name)``. An item whose ``parent_id`` is
    missing or does not refer to another item in ``items`` becomes a root.

    Note: an item that lists itself as parent (or a group of items forming a
    parent cycle) is never reachable from a root and so is absent from the
    result.
    """
    nodes = list(items)
    children_map: dict[str | None, list[Any]] = {}
    by_id = {item.id: item for item in nodes}
    for item in nodes:
        parent_id = getattr(item, "parent_id", None)
        # Orphans (parent not in this item set) are promoted to roots.
        if parent_id not in by_id:
            parent_id = None
        children_map.setdefault(parent_id, []).append(item)

    def convert(item) -> dict[str, Any]:
        data = dumper(item)
        data["children"] = [
            convert(child)
            for child in sorted(children_map.get(item.id, []), key=lambda value: (value.sort, value.name))
        ]
        return data

    return [
        convert(item)
        for item in sorted(children_map.get(None, []), key=lambda value: (value.sort, value.name))
    ]


def dump_dept(dept: SysDept, *, with_children: bool = False) -> dict[str, Any]:
    """Serialize a ``SysDept``.

    With ``with_children=True`` the result also gets a ``"children"`` list,
    built recursively from ``dept.children`` ordered by ``(sort, name)``.
    """
    data = {
        "id": dept.id,
        "name": dept.name,
        "parentId": dept.parent_id,
        "desc": dept.desc,
        "sort": dept.sort,
        "leaderId": dept.leader_id,
        "status": dept.status,
        "createdAt": dept.created_at,
        "updatedAt": dept.updated_at,
    }
    if with_children:
        data["children"] = [
            dump_dept(child, with_children=True)
            for child in sorted(dept.children, key=lambda item: (item.sort, item.name))
        ]
    return data


def dump_config(config: SysConfig) -> dict[str, Any]:
    """Serialize a ``SysConfig`` into a camelCase response dict."""
    return {
        "id": config.id,
        "type": config.type,
        "name": config.name,
        "code": config.code,
        "value": config.value,
        "desc": config.desc,
        "sort": config.sort,
        "status": config.status,
        "createdAt": config.created_at,
        "updatedAt": config.updated_at,
    }


def dump_dict_type(item: SysDictType, *, with_data: bool = False) -> dict[str, Any]:
    """Serialize a ``SysDictType``.

    With ``with_data=True`` the result also gets a ``"dataList"`` of the
    type's entries, ordered by ``(sort, label)``.
    """
    data = {
        "id": item.id,
        "typeName": item.type_name,
        "typeCode": item.type_code,
        "desc": item.desc,
        "sort": item.sort,
        "status": item.status,
        "createdAt": item.created_at,
        "updatedAt": item.updated_at,
    }
    if with_data:
        data["dataList"] = [
            dump_dict_data(value)
            for value in sorted(item.data_list, key=lambda value: (value.sort, value.label))
        ]
    return data


def dump_dict_data(item: SysDictData) -> dict[str, Any]:
    """Serialize a ``SysDictData`` into a camelCase response dict."""
    return {
        "id": item.id,
        "typeCode": item.type_code,
        "label": item.label,
        "code": item.code,
        "value": item.value,
        "desc": item.desc,
        "sort": item.sort,
        "status": item.status,
        "createdAt": item.created_at,
        "updatedAt": item.updated_at,
    }


def dump_log(item: SysLog) -> dict[str, Any]:
    """Serialize a ``SysLog`` into a camelCase response dict."""
    return {
        "id": item.id,
        "name": item.name,
        "method": item.method,
        "userId": item.user_id,
        "path": item.path,
        "ip": item.ip,
        "userAgent": item.user_agent,
        "executionTime": item.execution_time,
        "success": item.success,
        "desc": item.desc,
        "type": item.type,
        "createdAt": item.created_at,
        "updatedAt": item.updated_at,
    }


def dump_file(item: SysFile) -> dict[str, Any]:
    """Serialize a ``SysFile``; ``"url"`` is the ``/file/<id>/download`` path."""
    return {
        "id": item.id,
        "filename": item.filename,
        "fileKey": item.file_key,
        "fileHash": item.file_hash,
        "mimeType": item.mime_type,
        "fileSize": item.file_size,
        "extension": item.extension,
        "storageType": item.storage_type,
        "storageInfo": item.storage_info or {},
        "directoryId": item.directory_id,
        "metadata": item.metadata_ or {},
        "isDeleted": item.is_deleted,
        "shareCode": item.share_code,
        "shareExpireAt": item.share_expire_at,
        "shareCount": item.share_count,
        "status": item.status,
        "url": f"/file/{item.id}/download",
        "createdAt": item.created_at,
        "updatedAt": item.updated_at,
    }


def create_token_payload(user: User, config) -> dict[str, Any]:
    """Build the login response: access and refresh tokens, expiries and the user.

    Expiry values are read from ``config.jwt_access_token_expires_seconds`` and
    ``config.jwt_refresh_token_expires_seconds``.
    """
    return {
        "accessToken": create_access_token(user.id, config),
        "tokenType": "Bearer",
        "expiresIn": config.jwt_access_token_expires_seconds,
        "refreshToken": create_refresh_token(user.id, config),
        "refreshExpiresIn": config.jwt_refresh_token_expires_seconds,
        "user": dump_user(user),
    }


def find_login_user(
    db: Session,
    *,
    username: str | None = None,
    phone: str | None = None,
    email: str | None = None,
) -> User | None:
    """Look up a user by exactly one identifier, with relationships eagerly loaded.

    The first non-empty value among ``username``, ``phone`` and ``email`` (in
    that order) is used; the others are ignored. Returns ``None`` when all are
    empty or no row matches.
    """
    stmt = select(User).options(
        selectinload(User.roles).selectinload(Role.menus),
        selectinload(User.depts),
        selectinload(User.attributes),
    )
    if username:
        stmt = stmt.where(User.username == username)
    elif phone:
        stmt = stmt.where(User.phone == phone)
    elif email:
        stmt = stmt.where(User.email == email)
    else:
        return None
    return db.scalar(stmt)


def assert_unique_user(
    db: Session,
    *,
    user_id: str | None = None,
    username: str | None = None,
    phone: str | None = None,
    email: str | None = None,
) -> None:
    """Ensure no other user already uses the given username, phone or email.

    Empty values are skipped. When ``user_id`` is given, that user is excluded
    from the check so a record does not collide with itself.

    Raises:
        BizError: on the first field whose value is already taken.
    """
    checks = [
        (User.username, username, "用户名已存在"),
        (User.phone, phone, "手机号已存在"),
        (User.email, email, "邮箱已存在"),
    ]
    for column, value, message in checks:
        if not value:
            continue
        stmt = select(User).where(column == value)
        if user_id:
            stmt = stmt.where(User.id != user_id)
        if db.scalar(stmt):
            raise BizError(message)


def bind_roles(db: Session, role_ids: list[str]) -> list[Role]:
    """Return the ``Role`` rows whose id is in ``role_ids``; unknown ids are ignored."""
    if not role_ids:
        return []
    return list(db.scalars(select(Role).where(Role.id.in_(role_ids))))


def bind_depts(db: Session, dept_ids: list[str]) -> list[SysDept]:
    """Return the ``SysDept`` rows whose id is in ``dept_ids``; unknown ids are ignored."""
    if not dept_ids:
        return []
    return list(db.scalars(select(SysDept).where(SysDept.id.in_(dept_ids))))


def bind_menus(db: Session, menu_ids: list[str]) -> list[SysMenu]:
    """Return the ``SysMenu`` rows whose id is in ``menu_ids``; unknown ids are ignored."""
    if not menu_ids:
        return []
    return list(db.scalars(select(SysMenu).where(SysMenu.id.in_(menu_ids))))


def upsert_attribute(
    user: User,
    group: str,
    key: str,
    value: Any,
    attr_type: str = "string",
    description: str | None = None,
    sort: int = 0,
) -> SysUserAttribute:
    """Update the user's attribute identified by ``(group, key)`` or append a new one.

    On update, ``attr_type``, ``description`` and ``sort`` are overwritten with
    the passed values (including their defaults). Returns the attribute that
    was updated or created.
    """
    for item in user.attributes:
        if item.attr_group == group and item.attr_key == key:
            # The type is assigned before the value, as set_typed_value is
            # called on an item that already carries its attr_type.
            item.attr_type = attr_type
            item.description = description
            item.sort = sort
            item.set_typed_value(value)
            return item
    item = SysUserAttribute(
        attr_group=group,
        attr_key=key,
        attr_type=attr_type,
        description=description,
        sort=sort,
    )
    item.set_typed_value(value)
    user.attributes.append(item)
    return item


def make_share_code() -> str:
    """Return a random URL-safe share code generated from 12 random bytes."""
    return secrets.token_urlsafe(12)


def share_expire(seconds: int | None) -> datetime | None:
    """Return the expiry time ``seconds`` from now, or ``None`` for no expiry.

    ``None`` is returned when ``seconds`` is ``None`` or not positive. The
    result is a naive ``datetime`` in local time (``datetime.now()``).
    """
    if seconds is None or seconds <= 0:
        return None
    return datetime.now() + timedelta(seconds=seconds)


def new_id() -> str:
    """Return a new random UUID4 as a 32-character hex string."""
    return uuid.uuid4().hex