from marshmallow import post_dump
from sqlalchemy import distinct, select
from iti.applications.extensions import db, jwt, cache_simple
from sqlalchemy.ext.hybrid import hybrid_property
from werkzeug.security import generate_password_hash, check_password_hash
import datetime
from .sys_rel_role_menu import sys_role_menu
from .sys_rel_user_role import sys_user_role
from .sys_menu import SysMenu
from sqlalchemy.orm import joinedload
from iti.applications.common.enums import GenderEnum, StatusEnum
from iti.applications.common.utils import BaseSchema
from apiflask.fields import Dict, String, DateTime, Enum, Nested, List
from iti.applications.common.crud import BaseModelMixin
from .sys_role import Role


@jwt.user_lookup_loader
def user_lookup_loader(header, payload):
    """Resolve the current user for a decoded JWT.

    Registered through ``jwt.user_lookup_loader``. Reads the ``sub`` (user id)
    and ``exp`` (expiry timestamp) claims from ``payload`` and delegates to
    :func:`load_with_cache`. ``header`` is not used.
    """
    identity = payload.get("sub", None)
    # Token expiry; bounds how long the loaded user may stay cached.
    exp = payload.get("exp", None)
    return load_with_cache(identity=identity, exp=exp)


def load_with_cache(identity, exp):
    """Load a ``User`` by id, serving from and populating ``cache_simple``.

    :param identity: user id taken from the token's ``sub`` claim.
    :param exp: the token's ``exp`` claim (Unix timestamp), or ``None``.
    :return: the cached or freshly loaded ``User``; ``None`` when
        ``identity`` is ``None`` or no such user exists.
    """
    if identity is None:
        return None

    cache_key = f"user_{identity}"
    cached = cache_simple.get(key=cache_key)
    if cached is not None:
        return cached

    db_user = db.session.scalar(
        select(User)
        .filter_by(id=identity)
        .options(
            joinedload(User.roles).noload(Role.menus),
            joinedload(User.depts),
            joinedload(User.user_attributes),
        )
    )
    if db_user is None:
        return None

    # Touch the property so the permission list is computed (and memoized on
    # the instance) while the session is available, before it is cached.
    db_user.permissions

    # Without an expiry there is no safe cache lifetime; hand back the user
    # uncached instead of failing a lookup that actually succeeded.
    if exp is None:
        return db_user

    # Use total_seconds(): timedelta.seconds drops whole days and wraps a
    # negative delta to a large positive value. Only cache a positive TTL;
    # NOTE(review): with Flask-Caching backends a timeout of 0 means
    # "no expiry" — confirm for cache_simple's configured backend.
    expires_at = datetime.datetime.fromtimestamp(exp)
    ttl = int((expires_at - datetime.datetime.now()).total_seconds())
    if ttl > 0:
        cache_simple.set(key=cache_key, value=db_user, timeout=ttl)
    return db_user


class User(BaseModelMixin):
    """User table (``sys_user``)."""

    __tablename__ = "sys_user"
    username = db.Column(db.String(64), nullable=False, comment="用户名")
    phone = db.Column(db.String(13), nullable=True, comment="手机号")
    email = db.Column(db.String(255), nullable=True, comment="邮箱")
    # Stored hash; written through the ``password`` setter below.
    _password = db.Column("password", db.String(255), nullable=False, comment="密码")
    realname = db.Column(db.String(32), nullable=True, comment="真实姓名")
    desc = db.Column(db.Text, nullable=True, comment="描述")
    avatar = db.Column(db.String(255), nullable=True, comment="头像")
    gender = db.Column(
        db.Enum(GenderEnum, values_callable=lambda x: [e.value for e in x]),
        nullable=False,
        default=GenderEnum.SECURE.value,
        comment="性别",
    )
    status = db.Column(
        db.Enum(StatusEnum, values_callable=lambda x: [e.value for e in x]),
        nullable=False,
        default=StatusEnum.ENABLED.value,
        comment="状态",
    )

    # Relationships
    # Only roles whose status is 'enabled' are loaded.
    roles = db.relationship(
        "Role",
        secondary="sys_user_role",
        primaryjoin="User.id == sys_user_role.c.user_id",
        secondaryjoin="and_(Role.id == sys_user_role.c.role_id, Role.status == 'enabled')",
        back_populates="users",
    )
    # Only departments whose status is 'enabled' are loaded.
    depts = db.relationship(
        "SysDept",
        secondary="sys_user_dept",
        primaryjoin="User.id == sys_user_dept.c.user_id",
        secondaryjoin="and_(SysDept.id == sys_user_dept.c.dept_id, SysDept.status == 'enabled')",
        back_populates="users",
    )
    user_attributes = db.relationship(
        "SysUserAttribute",
        back_populates="user",
        lazy="selectin",
        cascade="all, delete-orphan",
        order_by="SysUserAttribute.sort",
    )

    @hybrid_property
    def password(self):
        """The stored password hash (never the plaintext)."""
        return self._password

    @password.setter
    def password(self, value):
        # Assigning None is ignored, so a partial update leaves the hash as-is.
        if value is not None:
            self._password = generate_password_hash(value, method="pbkdf2:sha256")

    def check_password(self, value) -> bool:
        """Return True if ``value`` matches the stored password hash."""
        return check_password_hash(self._password, value)

    # Per-instance memo for ``permissions``; None means "not computed yet".
    # (An empty list is a valid, cacheable result and must not trigger a
    # re-query on every access.)
    _permissions = None

    @hybrid_property
    def permissions(self):
        """Auth codes granted to this user, computed once per instance."""
        if self._permissions is None:
            self._permissions = self.get_permissions()
        return self._permissions

    @permissions.setter
    def permissions(self, value):
        self._permissions = value

    def get_permissions(self):
        """Query the distinct, ascending-sorted menu auth codes for this user.

        Only enabled menus with a non-NULL ``auth_code`` that are linked to
        one of the user's *enabled* roles count. The role-status filter
        mirrors the ``roles`` relationship, so a disabled role grants nothing.
        """
        permissions = db.session.scalars(
            select(distinct(SysMenu.auth_code))
            .join(sys_role_menu, SysMenu.id == sys_role_menu.c.menu_id)
            .join(sys_user_role, sys_role_menu.c.role_id == sys_user_role.c.role_id)
            .join(Role, Role.id == sys_user_role.c.role_id)
            .filter(
                sys_user_role.c.user_id == self.id,
                Role.status == StatusEnum.ENABLED.value,
                SysMenu.status == StatusEnum.ENABLED,
                SysMenu.auth_code.isnot(None),
            )
            .order_by(SysMenu.auth_code.asc())
        ).all()
        return permissions

    # Per-instance memo for ``attributes``; None means "not computed yet".
    _attributes = None

    @hybrid_property
    def attributes(self):
        """User extension attributes grouped into a nested dict.

        Shape: {"erp": {"erp_username": "xxx", ...}, "custom": {...}}
        """
        if self._attributes is None:
            self._attributes = self.get_attributes()
        return self._attributes

    @attributes.setter
    def attributes(self, value):
        self._attributes = value

    def get_attributes(self):
        """Group ``user_attributes`` as ``{attr_group: {attr_key: value}}``.

        Values come from each attribute's ``get_typed_value()``.
        """
        result = {}
        for attr in self.user_attributes:
            result.setdefault(attr.attr_group, {})[attr.attr_key] = attr.get_typed_value()
        return result

    def set_attributes(self, attributes_dict: dict):
        """Create or update extension attributes in bulk.

        Existing ``(group, key)`` entries are updated in place; missing ones
        are appended as new ``SysUserAttribute`` rows with type "string".
        The memoized ``attributes`` dict is invalidated afterwards.

        :param attributes_dict: {"erp": {"erp_username": "xxx", ...}, ...}
        """
        from .sys_user_attribute import SysUserAttribute

        # Index existing rows once instead of scanning the list per key.
        # setdefault keeps the first match, as the previous linear scan did.
        by_key = {}
        for attr in self.user_attributes:
            by_key.setdefault((attr.attr_group, attr.attr_key), attr)

        for group, attrs in attributes_dict.items():
            for key, value in attrs.items():
                existing = by_key.get((group, key))
                if existing is not None:
                    existing.set_typed_value(value)
                    continue
                new_attr = SysUserAttribute(
                    user_id=self.id,
                    attr_group=group,
                    attr_key=key,
                    attr_type="string",  # default type
                )
                new_attr.set_typed_value(value)
                self.user_attributes.append(new_attr)
                # Later duplicates in the same call update this new row.
                by_key[(group, key)] = new_attr

        # Invalidate the memoized grouped view.
        self._attributes = None


class UserSchema(BaseSchema):
    """Serialization schema for ``User``.

    :param roles_type: "code" (default) dumps ``roles`` as a list of role
        codes; any other value dumps role ids.
    """

    def __init__(self, *args, **kwargs):
        self.roles_type = kwargs.pop("roles_type", "code")  # code | id
        super().__init__(*args, **kwargs)

    id = String()
    username = String()
    phone = String()
    email = String()
    password = String(load_only=True)
    realname = String()
    avatar = String()
    gender = Enum(GenderEnum, by_value=True)
    status = Enum(StatusEnum, by_value=True)
    desc = String()
    created_at = DateTime(data_key="createdAt", format="%Y-%m-%d %H:%M:%S")
    updated_at = DateTime(data_key="updatedAt", format="%Y-%m-%d %H:%M:%S")

    # Relationships
    roles = Nested("RoleSchema", many=True, dump_only=True, exclude=["users"])
    depts = Nested(
        "SysDeptSchema",
        many=True,
        dump_only=True,
        exclude=["users", "children", "parent"],
    )
    permissions = List(String())
    attributes = Dict(dump_only=True)

    @post_dump
    def patch_roles(self, data, **kwargs):
        """Flatten nested roles/depts in the dumped output.

        ``roles`` becomes a list of role codes (or ids, per ``roles_type``);
        ``depts`` becomes a list of department ids.
        """
        if "roles" in data:
            field = "code" if self.roles_type == "code" else "id"
            data["roles"] = [role[field] for role in data["roles"]]
        if "depts" in data:
            data["depts"] = [dept["id"] for dept in data["depts"]]
        return data