You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

190 lines
6.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from dataclasses import field
from marshmallow_dataclass import dataclass
from apiflask.fields import Integer, URL, Field
from marshmallow import fields
from apiflask import Schema
from .str import camel_case
# ==================== Schema 定义 ====================
class BaseSchema(Schema):
"""
基础 Schema 扩展
1. 有序返回
2. 未知字段不报错
"""
class Meta:
ordered = True
unknown = "INCLUDE"
def on_bind_field(self, field_name: str, field_obj: Field) -> None:
"""
绑定字段时处理
1.统一驼峰命名返回(小驼峰)
"""
if field_obj.data_key is None:
field_obj.data_key = camel_case(field_name)
pass
@dataclass(base_schema=Schema)
class Pagination:
page: int = field(default=1, metadata={"metadata": {"description": "当前页码"}})
size: int = field(default=10, metadata={"metadata": {"description": "每页数量"}})
pages: int = field(default=0, metadata={"metadata": {"description": "总页数"}})
total: int = field(default=0, metadata={"metadata": {"description": "总记录数"}})
current: str = field(
default="",
metadata={"metadata": {"description": "当前页URL"}, "dump_only": True},
)
next: str = field(
default="",
metadata={"metadata": {"description": "下一页URL"}, "dump_only": True},
)
prev: str = field(
default="",
metadata={"metadata": {"description": "上一页URL"}, "dump_only": True},
)
first: str = field(
default="", metadata={"metadata": {"description": "首页URL"}, "dump_only": True}
)
last: str = field(
default="", metadata={"metadata": {"description": "末页URL"}, "dump_only": True}
)
# 获取 Pagination 类中的所有字段
pagination_fields = [field.name for field in Pagination.__dataclass_fields__.values()]
class PaginationSchema(Schema):
"""自定义分页信息 Schema与 APIFlask 保持一致)"""
page = Integer(metadata={"description": "当前页码"})
size = Integer(metadata={"description": "每页数量"}) # per_page → size
pages = Integer(metadata={"description": "总页数"})
total = Integer(metadata={"description": "总记录数"})
current = URL(metadata={"description": "当前页URL"}, dump_only=True)
next = URL(metadata={"description": "下一页URL"}, dump_only=True)
prev = URL(metadata={"description": "上一页URL"}, dump_only=True)
first = URL(metadata={"description": "首页URL"}, dump_only=True)
last = URL(metadata={"description": "末页URL"}, dump_only=True)
# 获取 PaginationSchema 类中的所有字段
pagination_schema_fields = list(PaginationSchema._declared_fields.keys())
def page_schema(item_schema_cls: type, *, schema_name: str | None = None) -> type:
"""
根据传入的 Item OutSchema 生成通用分页负载 Schema
{ items: [Item], page: Pagination }
- item_schema_cls: 基础 Schema
- schema_name: 可选,自定义 schema 名称
"""
if not schema_name:
schema_name = f"PageItems[{getattr(item_schema_cls, '__name__', 'Item')}]"
class _PageItemsSchema(Schema):
items = fields.Nested(item_schema_cls, many=True, required=True)
page = fields.Nested(PaginationSchema, required=True)
_PageItemsSchema.__name__ = schema_name
return _PageItemsSchema
def condition_schema(base_schema_cls: type, control_config):
"""
多字段控制动态Schema创建
Args:
base_schema_class: 基础 Schema 类
control_config: 控制配置字典
{
"withDataList": ["data_list"], # 当 withDataList=true 时包含 data_list
"withTimestamps": ["created_at", "updated_at"], # 当 withTimestamps=true 时包含时间字段
"withStatus": ["status"], # 当 withStatus=true 时包含状态字段
}
"""
class _ConditionSchema(base_schema_cls):
def __init__(self, *args, **kwargs):
self._control_config = control_config
super().__init__(*args, **kwargs)
def dump(self, obj, many=None, **kwargs):
from flask import request
# 收集排除字段列表
exclude_fields = []
for control_field, target_fields in self._control_config.items():
try:
control_value = (
request.args.get(control_field, "false").lower() == "true"
)
except Exception:
# 没有请求上下文时,默认不排除字段
control_value = True
if not control_value:
exclude_fields.extend(target_fields)
# 如果有字段需要排除,创建临时 Schema
if exclude_fields:
temp_schema = base_schema_cls(exclude=exclude_fields)
return temp_schema.dump(obj, many=many, **kwargs)
else:
return super().dump(obj, many=many, **kwargs)
return _ConditionSchema
def custom_schema_name_resolver(schema):
"""
自定义 schema 名称解析器,解决循环引用导致的命名冲突
根据 APIFlask 官方文档:
- 函数接收一个 schema 对象作为参数
- 返回一个字符串作为 schema 的名称
- 用于解决多个 schema 解析为相同名称的问题
处理策略:
1. 优先使用 Meta.name如果定义
2. 自动移除 Schema 后缀
3. 为带有 exclude 参数的嵌套 schema 生成唯一名称
"""
schema_class = schema.__class__
# 1. 优先检查是否在 Meta 中定义了 name
if hasattr(schema_class, "Meta") and hasattr(schema_class.Meta, "name"):
base_name = schema_class.Meta.name
else:
# 2. 使用类名,移除 Schema 后缀
base_name = schema_class.__name__
if base_name.endswith("Schema"):
base_name = base_name[:-6]
if schema.partial: # 为部分模式添加 "Update" 后缀
base_name += "Update"
# 3. 处理嵌套时的 exclude 参数
# 当使用 Nested("SomeSchema", exclude=["field1", "field2"]) 时
# apispec 会创建新的 schema 实例,需要为其生成唯一名称
if hasattr(schema, "exclude") and schema.exclude:
# 将 exclude 的字段排序,确保相同的 exclude 组合生成相同的名称
excluded_fields = sorted(schema.exclude)
# 生成简洁的后缀:首字母大写拼接
# 例如exclude=["children", "parent"] -> "ChildrenParent"
suffix = "".join([field.capitalize() for field in excluded_fields])
return f"{base_name}Exclude{suffix}"
# 4. 处理 only 参数(如果使用)
if hasattr(schema, "only") and schema.only:
only_fields = sorted(schema.only)
suffix = "".join([field.capitalize() for field in only_fields])
return f"{base_name}Only{suffix}"
return base_name