You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
iTi-System/iti_system/services.py

399 lines
12 KiB
Python

from __future__ import annotations
import secrets
import uuid
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import Any
from fastapi import Depends, Request
from sqlalchemy import Select, func, or_, select
from sqlalchemy.orm import Session, selectinload
from iti.auth import (
Principal,
StaticPermissionProvider,
create_access_token,
create_refresh_token,
require_user,
)
from iti.db import get_db
from iti.exceptions import BizError, Unauthorized
from iti.responses import pagination
from iti_system.enums import MenuTypeEnum, StatusEnum
from iti_system.models import (
Role,
SysConfig,
SysDept,
SysDictData,
SysDictType,
SysFile,
SysLog,
SysMenu,
SysUserAttribute,
User,
)
class SystemPermissionProvider(StaticPermissionProvider):
    """Permission provider backed by the system ``User``/``Role`` tables.

    Principals are loaded through the session factory stored on
    ``request.app.state.db_sessionmaker``. Holders of the ``ADMIN`` role code
    pass every permission check.
    """

    def load_principal(self, principal_id: str, request: Request) -> Principal | None:
        """Build a ``Principal`` for the enabled user ``principal_id``.

        Returns ``None`` when the app has no ``db_sessionmaker`` configured or
        when no enabled user with that id exists.
        """
        session_factory = getattr(request.app.state, "db_sessionmaker", None)
        if session_factory is None:
            return None
        with session_factory() as session:
            query = (
                select(User)
                .where(User.id == principal_id, User.status == StatusEnum.ENABLED.value)
                .options(selectinload(User.roles).selectinload(Role.menus))
            )
            found = session.scalar(query)
            if found is None:
                return None
            # Permissions/role codes are read while the session is still open,
            # since they are derived from the eagerly loaded roles and menus.
            return Principal(
                id=found.id,
                permissions=frozenset(found.permissions),
                roles=frozenset(found.role_codes),
            )

    def has_permission(self, principal: Principal, code: str) -> bool:
        """Return whether ``principal`` holds ``code``; ``ADMIN`` bypasses the check."""
        if "ADMIN" in principal.roles:
            return True
        return code in principal.permissions
def current_user(
    principal: Principal = Depends(require_user),
    db: Session = Depends(get_db),
) -> User:
    """FastAPI dependency that resolves the authenticated principal to a ``User``.

    Roles (with their menus), departments and attributes are eagerly loaded.

    Raises:
        Unauthorized: if no user with ``principal.id`` exists.
    """
    eager_loads = (
        selectinload(User.roles).selectinload(Role.menus),
        selectinload(User.depts),
        selectinload(User.attributes),
    )
    account = db.scalar(select(User).where(User.id == principal.id).options(*eager_loads))
    if account is not None:
        return account
    raise Unauthorized("用户不存在或已失效")
def paginate_query(db: Session, stmt: Select, *, page: int = 1, size: int = 10) -> dict[str, Any]:
    """Execute one page of ``stmt`` and return ``{"items": ..., "page": ...}``.

    ``page`` and ``size`` are clamped to a minimum of 1. The total row count is
    computed over ``stmt`` with its ORDER BY removed. A ``None`` count is
    treated as 0.
    """
    page_no = 1 if page < 1 else page
    page_size = 1 if size < 1 else size
    count_query = select(func.count()).select_from(stmt.order_by(None).subquery())
    total = db.scalar(count_query) or 0
    skip = (page_no - 1) * page_size
    rows = db.scalars(stmt.limit(page_size).offset(skip)).all()
    return {"items": rows, "page": pagination(page=page_no, size=page_size, total=total)}
def apply_keyword(stmt: Select, keyword: str | None, *columns) -> Select:
    """Restrict ``stmt`` to rows where any of ``columns`` contains ``keyword``.

    The keyword is matched as a literal substring. LIKE wildcards (``%``,
    ``_``) and the escape character inside it are escaped, so searching for
    ``"50%"`` or ``"_"`` no longer matches every row.

    Args:
        stmt: the SELECT statement to filter.
        keyword: search text. A falsy value leaves ``stmt`` unchanged.
        *columns: columns to search (OR-combined). If none are given,
            ``stmt`` is returned unchanged instead of building an empty ``or_()``.

    Returns:
        The filtered statement, or ``stmt`` itself when nothing is filtered.
    """
    if not keyword or not columns:
        return stmt
    # "/" is the escape character (the same one SQLAlchemy's ``autoescape``
    # uses) because a backslash is itself special in some dialects' literals.
    # Escape the escape character first so later replacements are not doubled.
    escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    return stmt.where(or_(*(column.like(pattern, escape="/") for column in columns)))
def get_or_404(db: Session, model, item_id: str, message: str = "数据不存在"):
    """Load ``model`` by primary key ``item_id``.

    Raises:
        BizError: with ``code=404`` and ``message`` when the row is absent.
    """
    found = db.get(model, item_id)
    if found is not None:
        return found
    raise BizError(message, code=404)
def dump_user(user: User) -> dict[str, Any]:
    """Serialize ``user`` into the camelCase dict returned by the API.

    ``roles`` lists role ids and ``roleCodes`` lists role codes. ``isSuper``
    is true when ``"ADMIN"`` is among the role codes.
    """
    codes = user.role_codes
    return dict(
        id=user.id,
        username=user.username,
        phone=user.phone,
        email=user.email,
        realname=user.realname,
        desc=user.desc,
        avatar=user.avatar,
        gender=user.gender,
        status=user.status,
        roles=[assigned.id for assigned in user.roles],
        roleCodes=codes,
        depts=user.dept_ids,
        permissions=user.permissions,
        isSuper="ADMIN" in codes,
        attributes=user.attribute_map,
        createdAt=user.created_at,
        updatedAt=user.updated_at,
    )
def dump_role(role: Role) -> dict[str, Any]:
    """Serialize ``role``; ``permissions`` holds the ids of its linked menus."""
    menu_ids = [linked.id for linked in role.menus]
    payload: dict[str, Any] = {"id": role.id, "name": role.name, "code": role.code}
    payload.update(desc=role.desc, sort=role.sort, status=role.status)
    payload["permissions"] = menu_ids
    payload["createdAt"] = role.created_at
    payload["updatedAt"] = role.updated_at
    return payload
def dump_menu(menu: SysMenu, *, with_children: bool = False) -> dict[str, Any]:
    """Serialize ``menu`` into its API dict.

    ``meta`` defaults to ``{}`` when unset. With ``with_children=True`` a
    ``children`` key is added, holding the recursively serialized child menus
    ordered by ``(sort, name)``.
    """
    payload: dict[str, Any] = dict(
        id=menu.id,
        name=menu.name,
        type=menu.type,
        path=menu.path,
        component=menu.component,
        redirect=menu.redirect,
        sort=menu.sort,
        authCode=menu.auth_code,
        meta=menu.meta or {},
        status=menu.status,
        parentId=menu.parent_id,
        createdAt=menu.created_at,
        updatedAt=menu.updated_at,
    )
    if not with_children:
        return payload
    ordered = sorted(menu.children, key=lambda child: (child.sort, child.name))
    payload["children"] = [dump_menu(child, with_children=True) for child in ordered]
    return payload
def build_tree(items: Iterable[Any], dumper) -> list[dict[str, Any]]:
    """Arrange flat ``items`` into nested dicts produced by ``dumper``.

    Each item must expose ``id``, ``sort`` and ``name``. ``parent_id`` is
    optional. Items whose parent is not among ``items`` become roots.
    Siblings are ordered by ``(sort, name)``. Every dumped node gets a
    ``children`` list.
    """
    pool = list(items)
    known_ids = {node.id for node in pool}
    grouped: dict[str | None, list[Any]] = {}
    for node in pool:
        parent = getattr(node, "parent_id", None)
        # Orphans (parent outside the given items) are promoted to roots.
        bucket = parent if parent in known_ids else None
        grouped.setdefault(bucket, []).append(node)

    def ordered(bucket):
        return sorted(grouped.get(bucket, []), key=lambda node: (node.sort, node.name))

    def render(node) -> dict[str, Any]:
        rendered = dumper(node)
        rendered["children"] = [render(child) for child in ordered(node.id)]
        return rendered

    return [render(root) for root in ordered(None)]
def visible_menu_tree_for_user(user: User, menus: Iterable[SysMenu]) -> list[dict[str, Any]]:
    """Return the navigation tree of enabled, non-button menus visible to ``user``.

    Users with the ``ADMIN`` role code see every such menu. Other users see a
    menu when it is assigned (and enabled) on one of their roles, or when its
    ``auth_code`` is in ``user.permissions``. Every ancestor of a visible menu
    is also included, provided it is itself an enabled non-button menu.
    """
    enabled = StatusEnum.ENABLED.value
    button = MenuTypeEnum.BUTTON.value
    candidates = [m for m in menus if m.status == enabled and m.type != button]
    if "ADMIN" in user.role_codes:
        return build_tree(candidates, dump_menu)

    granted_ids = {m.id for role in user.roles for m in role.menus if m.status == enabled}
    granted_codes = set(user.permissions)
    lookup = {m.id: m for m in candidates}

    def is_granted(menu: SysMenu) -> bool:
        if menu.id in granted_ids:
            return True
        return bool(menu.auth_code) and menu.auth_code in granted_codes

    shown: set[str] = set()
    for menu in candidates:
        if not is_granted(menu):
            continue
        # Walk up the parent chain; stop early once an already-shown node is
        # reached, because its ancestors were added at that time.
        node: SysMenu | None = menu
        while node is not None and node.id not in shown:
            shown.add(node.id)
            node = lookup.get(node.parent_id)
    return build_tree([m for m in candidates if m.id in shown], dump_menu)
def dump_dept(dept: SysDept, *, with_children: bool = False) -> dict[str, Any]:
    """Serialize ``dept`` into its API dict.

    With ``with_children=True`` a ``children`` key holds the recursively
    serialized sub-departments, ordered by ``(sort, name)``.
    """
    result: dict[str, Any] = {}
    result["id"] = dept.id
    result["name"] = dept.name
    result["parentId"] = dept.parent_id
    result["desc"] = dept.desc
    result["sort"] = dept.sort
    result["leaderId"] = dept.leader_id
    result["status"] = dept.status
    result["createdAt"] = dept.created_at
    result["updatedAt"] = dept.updated_at
    if with_children:
        subs = sorted(dept.children, key=lambda sub: (sub.sort, sub.name))
        result["children"] = [dump_dept(sub, with_children=True) for sub in subs]
    return result
def dump_config(config: SysConfig) -> dict[str, Any]:
    """Serialize a ``SysConfig`` row into its camelCase API dict."""
    # (output key, model attribute) pairs, in output order.
    fields = (
        ("id", "id"),
        ("type", "type"),
        ("name", "name"),
        ("code", "code"),
        ("value", "value"),
        ("desc", "desc"),
        ("sort", "sort"),
        ("status", "status"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {key: getattr(config, attr) for key, attr in fields}
def dump_dict_type(item: SysDictType, *, with_data: bool = False) -> dict[str, Any]:
    """Serialize a dictionary type.

    With ``with_data=True`` a ``dataList`` key holds its entries, serialized
    by ``dump_dict_data`` and ordered by ``(sort, label)``.
    """
    data: dict[str, Any] = dict(
        id=item.id,
        typeName=item.type_name,
        typeCode=item.type_code,
        desc=item.desc,
        sort=item.sort,
        status=item.status,
        createdAt=item.created_at,
        updatedAt=item.updated_at,
    )
    if with_data:
        entries = sorted(item.data_list, key=lambda entry: (entry.sort, entry.label))
        data["dataList"] = list(map(dump_dict_data, entries))
    return data
def dump_dict_data(item: SysDictData) -> dict[str, Any]:
    """Serialize one dictionary entry into its camelCase API dict."""
    # (output key, model attribute) pairs, in output order.
    mapping = (
        ("id", "id"),
        ("typeCode", "type_code"),
        ("label", "label"),
        ("code", "code"),
        ("value", "value"),
        ("desc", "desc"),
        ("sort", "sort"),
        ("status", "status"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {out_key: getattr(item, attr_name) for out_key, attr_name in mapping}
def dump_log(item: SysLog) -> dict[str, Any]:
    """Serialize a ``SysLog`` row into its camelCase API dict."""
    # (output key, model attribute) pairs, in output order.
    columns = (
        ("id", "id"),
        ("name", "name"),
        ("method", "method"),
        ("userId", "user_id"),
        ("path", "path"),
        ("ip", "ip"),
        ("userAgent", "user_agent"),
        ("executionTime", "execution_time"),
        ("success", "success"),
        ("desc", "desc"),
        ("type", "type"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {key: getattr(item, attr) for key, attr in columns}
def dump_file(item: SysFile) -> dict[str, Any]:
    """Serialize a ``SysFile`` row into its camelCase API dict.

    ``storageInfo`` and ``metadata`` default to ``{}`` when unset. ``url``
    is the ``/file/<id>/download`` path.
    """
    download_url = f"/file/{item.id}/download"
    return dict(
        id=item.id,
        filename=item.filename,
        fileKey=item.file_key,
        fileHash=item.file_hash,
        mimeType=item.mime_type,
        fileSize=item.file_size,
        extension=item.extension,
        storageType=item.storage_type,
        storageInfo=item.storage_info or {},
        directoryId=item.directory_id,
        metadata=item.metadata_ or {},
        isDeleted=item.is_deleted,
        shareCode=item.share_code,
        shareExpireAt=item.share_expire_at,
        shareCount=item.share_count,
        status=item.status,
        url=download_url,
        createdAt=item.created_at,
        updatedAt=item.updated_at,
    )
def create_token_payload(user: User, config) -> dict[str, Any]:
    """Issue an access/refresh token pair for ``user`` and bundle it with the user dict.

    Expiry values are copied from ``config.jwt_access_token_expires_seconds``
    and ``config.jwt_refresh_token_expires_seconds``.
    """
    payload: dict[str, Any] = {}
    payload["accessToken"] = create_access_token(user.id, config)
    payload["tokenType"] = "Bearer"
    payload["expiresIn"] = config.jwt_access_token_expires_seconds
    payload["refreshToken"] = create_refresh_token(user.id, config)
    payload["refreshExpiresIn"] = config.jwt_refresh_token_expires_seconds
    payload["user"] = dump_user(user)
    return payload
def find_login_user(db: Session, *, username: str | None = None, phone: str | None = None, email: str | None = None) -> User | None:
    """Look up a user by the first non-empty of ``username``, ``phone``, ``email``.

    Roles (with menus), departments and attributes are eagerly loaded.
    Returns ``None`` when all three identifiers are empty or no user matches.
    """
    if username:
        condition = User.username == username
    elif phone:
        condition = User.phone == phone
    elif email:
        condition = User.email == email
    else:
        return None
    query = (
        select(User)
        .options(
            selectinload(User.roles).selectinload(Role.menus),
            selectinload(User.depts),
            selectinload(User.attributes),
        )
        .where(condition)
    )
    return db.scalar(query)
def assert_unique_user(db: Session, *, user_id: str | None = None, username: str | None = None, phone: str | None = None, email: str | None = None) -> None:
    """Raise ``BizError`` if another user already has the given username/phone/email.

    Empty values are skipped. When ``user_id`` is set, that user is excluded
    from the search, so an update may keep its own values.
    """
    uniqueness_rules = (
        (User.username, username, "用户名已存在"),
        (User.phone, phone, "手机号已存在"),
        (User.email, email, "邮箱已存在"),
    )
    for column, candidate, error_message in uniqueness_rules:
        if candidate:
            query = select(User).where(column == candidate)
            if user_id:
                query = query.where(User.id != user_id)
            if db.scalar(query):
                raise BizError(error_message)
def bind_roles(db: Session, role_ids: list[str]) -> list[Role]:
    """Load the roles whose ids appear in ``role_ids``.

    Ids with no matching row are simply absent from the result. An empty input
    returns ``[]`` without querying.
    """
    if role_ids:
        return list(db.scalars(select(Role).where(Role.id.in_(role_ids))))
    return []
def bind_depts(db: Session, dept_ids: list[str]) -> list[SysDept]:
    """Load the departments whose ids appear in ``dept_ids``.

    Ids with no matching row are simply absent from the result. An empty input
    returns ``[]`` without querying.
    """
    if dept_ids:
        return list(db.scalars(select(SysDept).where(SysDept.id.in_(dept_ids))))
    return []
def bind_menus(db: Session, menu_ids: list[str]) -> list[SysMenu]:
    """Load the menus whose ids appear in ``menu_ids``.

    Ids with no matching row are simply absent from the result. An empty input
    returns ``[]`` without querying.
    """
    if menu_ids:
        return list(db.scalars(select(SysMenu).where(SysMenu.id.in_(menu_ids))))
    return []
def upsert_attribute(user: User, group: str, key: str, value: Any, attr_type: str = "string", description: str | None = None, sort: int = 0) -> SysUserAttribute:
    """Create or update the ``(group, key)`` attribute on ``user``.

    If an attribute with that group and key already exists in
    ``user.attributes`` (the first match is used), its type, description,
    sort and value are overwritten. Otherwise a new ``SysUserAttribute`` is
    appended. The value is stored through ``set_typed_value``.

    Returns:
        The updated or newly created attribute.
    """
    existing = next(
        (attr for attr in user.attributes if attr.attr_group == group and attr.attr_key == key),
        None,
    )
    if existing is None:
        created = SysUserAttribute(
            attr_group=group,
            attr_key=key,
            attr_type=attr_type,
            description=description,
            sort=sort,
        )
        created.set_typed_value(value)
        user.attributes.append(created)
        return created
    existing.attr_type = attr_type
    existing.description = description
    existing.sort = sort
    existing.set_typed_value(value)
    return existing
def make_share_code() -> str:
    """Return a random URL-safe share code built from 12 bytes of ``secrets`` entropy."""
    entropy_bytes = 12
    code = secrets.token_urlsafe(entropy_bytes)
    return code
def share_expire(seconds: int | None) -> datetime | None:
    """Return the naive local time ``seconds`` from now.

    ``None`` or a non-positive value yields ``None`` (presumably "no expiry";
    confirm with callers).
    """
    if seconds is not None and seconds > 0:
        return datetime.now() + timedelta(seconds=seconds)
    return None
def new_id() -> str:
    """Return a fresh random identifier: a UUID4 rendered as 32 lowercase hex digits."""
    identifier = uuid.uuid4()
    return identifier.hex