You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
399 lines
12 KiB
Python
399 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import secrets
|
|
import uuid
|
|
from collections.abc import Iterable
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
|
|
from fastapi import Depends, Request
|
|
from sqlalchemy import Select, func, or_, select
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
from iti.auth import (
|
|
Principal,
|
|
StaticPermissionProvider,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
require_user,
|
|
)
|
|
from iti.db import get_db
|
|
from iti.exceptions import BizError, Unauthorized
|
|
from iti.responses import pagination
|
|
|
|
from iti_system.enums import MenuTypeEnum, StatusEnum
|
|
from iti_system.models import (
|
|
Role,
|
|
SysConfig,
|
|
SysDept,
|
|
SysDictData,
|
|
SysDictType,
|
|
SysFile,
|
|
SysLog,
|
|
SysMenu,
|
|
SysUserAttribute,
|
|
User,
|
|
)
|
|
|
|
|
|
class SystemPermissionProvider(StaticPermissionProvider):
    """Permission provider backed by the system user/role/menu tables.

    Principals are only produced for enabled users, and any principal whose
    role codes include ``ADMIN`` passes every permission check.
    """

    def load_principal(self, principal_id: str, request: Request) -> Principal | None:
        """Build a ``Principal`` for the enabled user with id ``principal_id``.

        Returns ``None`` when the app state has no ``db_sessionmaker`` or when
        no enabled user with that id exists.
        """
        session_factory = getattr(request.app.state, "db_sessionmaker", None)
        if session_factory is None:
            return None

        query = (
            select(User)
            .where(User.id == principal_id, User.status == StatusEnum.ENABLED.value)
            # Eager-load roles and their menus up front; the principal's
            # permissions and role codes are read from the loaded user below.
            .options(selectinload(User.roles).selectinload(Role.menus))
        )
        with session_factory() as db:
            found = db.scalar(query)
            if found is None:
                return None
            return Principal(
                id=found.id,
                permissions=frozenset(found.permissions),
                roles=frozenset(found.role_codes),
            )

    def has_permission(self, principal: Principal, code: str) -> bool:
        """Return whether ``principal`` holds ``code``; the ``ADMIN`` role bypasses the check."""
        if "ADMIN" in principal.roles:
            return True
        return code in principal.permissions
|
|
|
|
|
|
def current_user(
    principal: Principal = Depends(require_user),
    db: Session = Depends(get_db),
) -> User:
    """FastAPI dependency returning the authenticated, enabled ``User``.

    The user behind ``principal`` is loaded with roles (and their menus),
    departments and attributes eager-loaded.

    Raises:
        Unauthorized: if the user no longer exists or is not enabled.
    """
    user = db.scalar(
        select(User)
        # Also require an enabled account (as SystemPermissionProvider.load_principal
        # does) so a still-valid token cannot keep acting for a disabled user.
        .where(User.id == principal.id, User.status == StatusEnum.ENABLED.value)
        .options(
            selectinload(User.roles).selectinload(Role.menus),
            selectinload(User.depts),
            selectinload(User.attributes),
        )
    )
    if user is None:
        raise Unauthorized("用户不存在或已失效")
    return user
|
|
|
|
|
|
def paginate_query(db: Session, stmt: Select, *, page: int = 1, size: int = 10) -> dict[str, Any]:
    """Run one page of ``stmt`` and return its rows together with pagination info.

    ``page`` and ``size`` are clamped to a minimum of 1. The total is counted
    over ``stmt`` with its ORDER BY stripped, wrapped in a subquery.

    Returns:
        ``{"items": <scalar rows of the page>, "page": pagination(...)}``.
    """
    page_no = max(page, 1)
    page_size = max(size, 1)

    count_stmt = select(func.count()).select_from(stmt.order_by(None).subquery())
    total = db.scalar(count_stmt) or 0

    skip = (page_no - 1) * page_size
    rows = db.scalars(stmt.limit(page_size).offset(skip)).all()
    return {
        "items": rows,
        "page": pagination(page=page_no, size=page_size, total=total),
    }
|
|
|
|
|
|
def apply_keyword(stmt: Select, keyword: str | None, *columns) -> Select:
    """Restrict ``stmt`` to rows where any of ``columns`` contains ``keyword``.

    The keyword is matched literally as a substring. LIKE wildcards typed by
    the user (``%``, ``_``) are escaped rather than interpreted. A falsy
    ``keyword`` or an empty ``columns`` returns ``stmt`` unchanged.
    """
    if not keyword or not columns:
        # No columns would mean calling or_() with no arguments, which
        # SQLAlchemy deprecates; there is nothing to filter on anyway.
        return stmt
    # autoescape=True escapes "%", "_" and the escape character inside the
    # value, so input such as "50%" does not turn into a wildcard pattern.
    return stmt.where(or_(*(column.contains(keyword, autoescape=True) for column in columns)))
|
|
|
|
|
|
def get_or_404(db: Session, model, item_id: str, message: str = "数据不存在"):
    """Return the ``model`` row with primary key ``item_id``.

    Raises:
        BizError: with ``code=404`` and ``message`` when no such row exists.
    """
    found = db.get(model, item_id)
    if found is not None:
        return found
    raise BizError(message, code=404)
|
|
|
|
|
|
def dump_user(user: User) -> dict[str, Any]:
    """Serialize ``user`` into the camelCase dict exposed by the API.

    ``roles`` lists role ids and ``roleCodes`` lists role codes. ``isSuper``
    is ``True`` exactly when ``ADMIN`` is among those codes.
    """
    codes = user.role_codes
    role_ids = [role.id for role in user.roles]
    is_super = "ADMIN" in codes
    return {
        "id": user.id,
        "username": user.username,
        "phone": user.phone,
        "email": user.email,
        "realname": user.realname,
        "desc": user.desc,
        "avatar": user.avatar,
        "gender": user.gender,
        "status": user.status,
        "roles": role_ids,
        "roleCodes": codes,
        "depts": user.dept_ids,
        "permissions": user.permissions,
        "isSuper": is_super,
        "attributes": user.attribute_map,
        "createdAt": user.created_at,
        "updatedAt": user.updated_at,
    }
|
|
|
|
|
|
def dump_role(role: Role) -> dict[str, Any]:
    """Serialize a role for the API; ``permissions`` holds the ids of its bound menus."""
    menu_ids = [menu.id for menu in role.menus]
    return {
        "id": role.id,
        "name": role.name,
        "code": role.code,
        "desc": role.desc,
        "sort": role.sort,
        "status": role.status,
        "permissions": menu_ids,
        "createdAt": role.created_at,
        "updatedAt": role.updated_at,
    }
|
|
|
|
|
|
def dump_menu(menu: SysMenu, *, with_children: bool = False) -> dict[str, Any]:
    """Serialize a menu row into its API dict.

    ``meta`` falls back to ``{}`` when unset. With ``with_children=True`` the
    ``menu.children`` relation is serialized recursively, and each level is
    ordered by ``(sort, name)``.
    """
    result: dict[str, Any] = {
        "id": menu.id,
        "name": menu.name,
        "type": menu.type,
        "path": menu.path,
        "component": menu.component,
        "redirect": menu.redirect,
        "sort": menu.sort,
        "authCode": menu.auth_code,
        "meta": menu.meta or {},
        "status": menu.status,
        "parentId": menu.parent_id,
        "createdAt": menu.created_at,
        "updatedAt": menu.updated_at,
    }
    if not with_children:
        return result

    ordered = sorted(menu.children, key=lambda child: (child.sort, child.name))
    result["children"] = [dump_menu(child, with_children=True) for child in ordered]
    return result
|
|
|
|
|
|
def build_tree(items: Iterable[Any], dumper) -> list[dict[str, Any]]:
    """Nest flat ``items`` into a tree of dicts using their ``parent_id`` links.

    Each node is produced by ``dumper(item)`` and then given a ``children``
    list. An item becomes a root when its ``parent_id`` is missing or does not
    name another item in ``items``. Siblings are ordered by ``(sort, name)``.
    """
    pool = list(items)
    known_ids = {node.id for node in pool}

    grouped: dict[str | None, list[Any]] = {}
    for node in pool:
        parent = getattr(node, "parent_id", None)
        # Orphans (parent not in this batch) are promoted to the root level.
        key = parent if parent in known_ids else None
        grouped.setdefault(key, []).append(node)

    def ordered(nodes):
        # Stable sibling order shared by every level of the tree.
        return sorted(nodes, key=lambda n: (n.sort, n.name))

    def expand(node) -> dict[str, Any]:
        # Dump the node first, then attach its recursively expanded children.
        dumped = dumper(node)
        dumped["children"] = [expand(child) for child in ordered(grouped.get(node.id, []))]
        return dumped

    return [expand(root) for root in ordered(grouped.get(None, []))]
|
|
|
|
|
|
def visible_menu_tree_for_user(user: User, menus: Iterable[SysMenu]) -> list[dict[str, Any]]:
    """Build the navigation tree of enabled, non-button menus visible to ``user``.

    Users whose role codes include ``ADMIN`` get every enabled non-button
    menu. For other users, a candidate menu is shown when it is an enabled
    menu linked to one of their roles, or when its ``auth_code`` is in
    ``user.permissions``. Its ancestors are then added by following
    ``parent_id`` links for as long as the parent is itself a candidate, so
    the resulting tree stays connected.
    """
    enabled = StatusEnum.ENABLED.value
    candidates = [
        menu
        for menu in menus
        if menu.status == enabled and menu.type != MenuTypeEnum.BUTTON.value
    ]
    if "ADMIN" in user.role_codes:
        return build_tree(candidates, dump_menu)

    # NOTE(review): role status is not checked here; confirm whether menus of
    # disabled roles should still be granted.
    granted_ids = set()
    for role in user.roles:
        for menu in role.menus:
            if menu.status == enabled:
                granted_ids.add(menu.id)
    granted_codes = set(user.permissions)
    by_id = {menu.id: menu for menu in candidates}

    def is_granted(menu) -> bool:
        # Directly linked menus, or menus whose auth code the user holds.
        if menu.id in granted_ids:
            return True
        return bool(menu.auth_code) and menu.auth_code in granted_codes

    shown: set[str] = set()
    for menu in candidates:
        if not is_granted(menu):
            continue
        node = menu
        # Climb towards the root. Stop at a node already shown (its ancestors
        # were added then) or at a parent that is not a candidate.
        while node is not None and node.id not in shown:
            shown.add(node.id)
            node = by_id.get(node.parent_id)

    return build_tree([menu for menu in candidates if menu.id in shown], dump_menu)
|
|
|
|
|
|
def dump_dept(dept: SysDept, *, with_children: bool = False) -> dict[str, Any]:
    """Serialize a department for the API.

    With ``with_children=True`` the ``dept.children`` relation is serialized
    recursively, and each level is ordered by ``(sort, name)``.
    """
    result: dict[str, Any] = {
        "id": dept.id,
        "name": dept.name,
        "parentId": dept.parent_id,
        "desc": dept.desc,
        "sort": dept.sort,
        "leaderId": dept.leader_id,
        "status": dept.status,
        "createdAt": dept.created_at,
        "updatedAt": dept.updated_at,
    }
    if not with_children:
        return result

    ordered = sorted(dept.children, key=lambda child: (child.sort, child.name))
    result["children"] = [dump_dept(child, with_children=True) for child in ordered]
    return result
|
|
|
|
|
|
def dump_config(config: SysConfig) -> dict[str, Any]:
    """Serialize a system configuration entry into its camelCase API dict."""
    # (output key, model attribute) pairs, in output order.
    columns = (
        ("id", "id"),
        ("type", "type"),
        ("name", "name"),
        ("code", "code"),
        ("value", "value"),
        ("desc", "desc"),
        ("sort", "sort"),
        ("status", "status"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {key: getattr(config, attr) for key, attr in columns}
|
|
|
|
|
|
def dump_dict_type(item: SysDictType, *, with_data: bool = False) -> dict[str, Any]:
    """Serialize a dictionary type for the API.

    With ``with_data=True`` its entries (``item.data_list``) are included
    under ``dataList``, serialized by ``dump_dict_data`` and ordered by
    ``(sort, label)``.
    """
    result: dict[str, Any] = {
        "id": item.id,
        "typeName": item.type_name,
        "typeCode": item.type_code,
        "desc": item.desc,
        "sort": item.sort,
        "status": item.status,
        "createdAt": item.created_at,
        "updatedAt": item.updated_at,
    }
    if not with_data:
        return result

    entries = sorted(item.data_list, key=lambda entry: (entry.sort, entry.label))
    result["dataList"] = [dump_dict_data(entry) for entry in entries]
    return result
|
|
|
|
|
|
def dump_dict_data(item: SysDictData) -> dict[str, Any]:
    """Serialize a single dictionary entry into its camelCase API dict."""
    return dict(
        id=item.id,
        typeCode=item.type_code,
        label=item.label,
        code=item.code,
        value=item.value,
        desc=item.desc,
        sort=item.sort,
        status=item.status,
        createdAt=item.created_at,
        updatedAt=item.updated_at,
    )
|
|
|
|
|
|
def dump_log(item: SysLog) -> dict[str, Any]:
    """Serialize a system log record into its camelCase API dict."""
    # (output key, model attribute) pairs, in output order.
    columns = (
        ("id", "id"),
        ("name", "name"),
        ("method", "method"),
        ("userId", "user_id"),
        ("path", "path"),
        ("ip", "ip"),
        ("userAgent", "user_agent"),
        ("executionTime", "execution_time"),
        ("success", "success"),
        ("desc", "desc"),
        ("type", "type"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {key: getattr(item, attr) for key, attr in columns}
|
|
|
|
|
|
def dump_file(item: SysFile) -> dict[str, Any]:
    """Serialize a stored-file record for the API.

    ``storageInfo`` and ``metadata`` fall back to ``{}`` when unset, and
    ``url`` is the download route built from the file id.
    """
    storage_info = item.storage_info or {}
    extra = item.metadata_ or {}
    download_url = f"/file/{item.id}/download"
    return {
        "id": item.id,
        "filename": item.filename,
        "fileKey": item.file_key,
        "fileHash": item.file_hash,
        "mimeType": item.mime_type,
        "fileSize": item.file_size,
        "extension": item.extension,
        "storageType": item.storage_type,
        "storageInfo": storage_info,
        "directoryId": item.directory_id,
        "metadata": extra,
        "isDeleted": item.is_deleted,
        "shareCode": item.share_code,
        "shareExpireAt": item.share_expire_at,
        "shareCount": item.share_count,
        "status": item.status,
        "url": download_url,
        "createdAt": item.created_at,
        "updatedAt": item.updated_at,
    }
|
|
|
|
|
|
def create_token_payload(user: User, config) -> dict[str, Any]:
    """Issue an access/refresh token pair for ``user`` and build the login response.

    ``config`` is passed to both token helpers. Its
    ``jwt_access_token_expires_seconds`` and
    ``jwt_refresh_token_expires_seconds`` values are echoed back as
    ``expiresIn`` and ``refreshExpiresIn``.
    """
    access_token = create_access_token(user.id, config)
    access_ttl = config.jwt_access_token_expires_seconds
    refresh_token = create_refresh_token(user.id, config)
    refresh_ttl = config.jwt_refresh_token_expires_seconds
    return {
        "accessToken": access_token,
        "tokenType": "Bearer",
        "expiresIn": access_ttl,
        "refreshToken": refresh_token,
        "refreshExpiresIn": refresh_ttl,
        "user": dump_user(user),
    }
|
|
|
|
|
|
def find_login_user(
    db: Session,
    *,
    username: str | None = None,
    phone: str | None = None,
    email: str | None = None,
) -> User | None:
    """Look up a login candidate by a single identifier.

    The first truthy identifier wins, in the order username, phone, email.
    If none is given, ``None`` is returned without querying. Roles (with
    their menus), departments and attributes are eager-loaded.
    """
    if username:
        condition = User.username == username
    elif phone:
        condition = User.phone == phone
    elif email:
        condition = User.email == email
    else:
        return None

    query = (
        select(User)
        .where(condition)
        .options(
            selectinload(User.roles).selectinload(Role.menus),
            selectinload(User.depts),
            selectinload(User.attributes),
        )
    )
    return db.scalar(query)
|
|
|
|
|
|
def assert_unique_user(
    db: Session,
    *,
    user_id: str | None = None,
    username: str | None = None,
    phone: str | None = None,
    email: str | None = None,
) -> None:
    """Ensure no other user already has the given username, phone or email.

    Empty or ``None`` values are skipped. When ``user_id`` is given, that
    user is excluded from the check so an update may keep its own values.
    Checks run in the order username, phone, email.

    Raises:
        BizError: with the message for the first value already in use.
    """
    candidates = (
        (User.username, username, "用户名已存在"),
        (User.phone, phone, "手机号已存在"),
        (User.email, email, "邮箱已存在"),
    )
    for column, value, message in candidates:
        if not value:
            continue
        filters = [column == value]
        if user_id:
            filters.append(User.id != user_id)
        if db.scalar(select(User).where(*filters)):
            raise BizError(message)
|
|
|
|
|
|
def bind_roles(db: Session, role_ids: list[str]) -> list[Role]:
    """Load the roles whose ids appear in ``role_ids``; unknown ids are silently skipped.

    An empty (or falsy) ``role_ids`` returns ``[]`` without querying.
    """
    if role_ids:
        return list(db.scalars(select(Role).where(Role.id.in_(role_ids))))
    return []
|
|
|
|
|
|
def bind_depts(db: Session, dept_ids: list[str]) -> list[SysDept]:
    """Load the departments whose ids appear in ``dept_ids``; unknown ids are silently skipped.

    An empty (or falsy) ``dept_ids`` returns ``[]`` without querying.
    """
    if dept_ids:
        return list(db.scalars(select(SysDept).where(SysDept.id.in_(dept_ids))))
    return []
|
|
|
|
|
|
def bind_menus(db: Session, menu_ids: list[str]) -> list[SysMenu]:
    """Load the menus whose ids appear in ``menu_ids``; unknown ids are silently skipped.

    An empty (or falsy) ``menu_ids`` returns ``[]`` without querying.
    """
    if menu_ids:
        return list(db.scalars(select(SysMenu).where(SysMenu.id.in_(menu_ids))))
    return []
|
|
|
|
|
|
def upsert_attribute(
    user: User,
    group: str,
    key: str,
    value: Any,
    attr_type: str = "string",
    description: str | None = None,
    sort: int = 0,
) -> SysUserAttribute:
    """Create or update the attribute identified by ``(group, key)`` on ``user``.

    The first existing attribute with that group and key is updated in place
    (type, description, sort, then value). Otherwise a new
    ``SysUserAttribute`` is appended to ``user.attributes``. Values are
    stored through ``set_typed_value``. Returns the updated or created
    attribute.
    """
    existing = next(
        (attr for attr in user.attributes if attr.attr_group == group and attr.attr_key == key),
        None,
    )
    if existing is None:
        created = SysUserAttribute(
            attr_group=group,
            attr_key=key,
            attr_type=attr_type,
            description=description,
            sort=sort,
        )
        created.set_typed_value(value)
        user.attributes.append(created)
        return created

    existing.attr_type = attr_type
    existing.description = description
    existing.sort = sort
    existing.set_typed_value(value)
    return existing
|
|
|
|
|
|
def make_share_code() -> str:
    """Return a random URL-safe share code built from 12 bytes of ``secrets`` randomness."""
    return secrets.token_urlsafe(nbytes=12)
|
|
|
|
|
|
def share_expire(seconds: int | None) -> datetime | None:
    """Return the naive local-time moment ``seconds`` from now, or ``None`` for no expiry.

    ``None``, zero and negative values all mean the share never expires.
    """
    never_expires = seconds is None or seconds <= 0
    if never_expires:
        return None
    lifetime = timedelta(seconds=seconds)
    return datetime.now() + lifetime
|
|
|
|
|
|
def new_id() -> str:
    """Return a fresh random identifier: a UUID4 rendered as 32 lowercase hex characters."""
    identifier = uuid.uuid4()
    return identifier.hex
|