|
|
from dataclasses import field
|
|
|
from marshmallow_dataclass import dataclass
|
|
|
from apiflask.fields import Integer, URL, Field
|
|
|
from marshmallow import fields
|
|
|
from apiflask import Schema
|
|
|
from .str import camel_case
|
|
|
|
|
|
|
|
|
# ==================== Schema 定义 ====================
|
|
|
class BaseSchema(Schema):
|
|
|
"""
|
|
|
基础 Schema 扩展
|
|
|
1. 有序返回
|
|
|
2. 未知字段不报错
|
|
|
"""
|
|
|
|
|
|
class Meta:
|
|
|
ordered = True
|
|
|
unknown = "INCLUDE"
|
|
|
|
|
|
def on_bind_field(self, field_name: str, field_obj: Field) -> None:
|
|
|
"""
|
|
|
绑定字段时处理
|
|
|
1.统一驼峰命名返回(小驼峰)
|
|
|
"""
|
|
|
if field_obj.data_key is None:
|
|
|
field_obj.data_key = camel_case(field_name)
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
@dataclass(base_schema=Schema)
|
|
|
class Pagination:
|
|
|
page: int = field(default=1, metadata={"metadata": {"description": "当前页码"}})
|
|
|
size: int = field(default=10, metadata={"metadata": {"description": "每页数量"}})
|
|
|
pages: int = field(default=0, metadata={"metadata": {"description": "总页数"}})
|
|
|
total: int = field(default=0, metadata={"metadata": {"description": "总记录数"}})
|
|
|
current: str = field(
|
|
|
default="",
|
|
|
metadata={"metadata": {"description": "当前页URL"}, "dump_only": True},
|
|
|
)
|
|
|
next: str = field(
|
|
|
default="",
|
|
|
metadata={"metadata": {"description": "下一页URL"}, "dump_only": True},
|
|
|
)
|
|
|
prev: str = field(
|
|
|
default="",
|
|
|
metadata={"metadata": {"description": "上一页URL"}, "dump_only": True},
|
|
|
)
|
|
|
first: str = field(
|
|
|
default="", metadata={"metadata": {"description": "首页URL"}, "dump_only": True}
|
|
|
)
|
|
|
last: str = field(
|
|
|
default="", metadata={"metadata": {"description": "末页URL"}, "dump_only": True}
|
|
|
)
|
|
|
|
|
|
|
|
|
# 获取 Pagination 类中的所有字段
|
|
|
pagination_fields = [field.name for field in Pagination.__dataclass_fields__.values()]
|
|
|
|
|
|
|
|
|
class PaginationSchema(Schema):
|
|
|
"""自定义分页信息 Schema(与 APIFlask 保持一致)"""
|
|
|
|
|
|
page = Integer(metadata={"description": "当前页码"})
|
|
|
size = Integer(metadata={"description": "每页数量"}) # per_page → size
|
|
|
pages = Integer(metadata={"description": "总页数"})
|
|
|
total = Integer(metadata={"description": "总记录数"})
|
|
|
current = URL(metadata={"description": "当前页URL"}, dump_only=True)
|
|
|
next = URL(metadata={"description": "下一页URL"}, dump_only=True)
|
|
|
prev = URL(metadata={"description": "上一页URL"}, dump_only=True)
|
|
|
first = URL(metadata={"description": "首页URL"}, dump_only=True)
|
|
|
last = URL(metadata={"description": "末页URL"}, dump_only=True)
|
|
|
|
|
|
|
|
|
# 获取 PaginationSchema 类中的所有字段
|
|
|
pagination_schema_fields = list(PaginationSchema._declared_fields.keys())
|
|
|
|
|
|
|
|
|
def page_schema(item_schema_cls: type, *, schema_name: str | None = None) -> type:
|
|
|
"""
|
|
|
根据传入的 Item OutSchema 生成通用分页负载 Schema:
|
|
|
{ items: [Item], page: Pagination }
|
|
|
|
|
|
- item_schema_cls: 基础 Schema
|
|
|
- schema_name: 可选,自定义 schema 名称
|
|
|
"""
|
|
|
if not schema_name:
|
|
|
schema_name = f"PageItems[{getattr(item_schema_cls, '__name__', 'Item')}]"
|
|
|
|
|
|
class _PageItemsSchema(Schema):
|
|
|
items = fields.Nested(item_schema_cls, many=True, required=True)
|
|
|
page = fields.Nested(PaginationSchema, required=True)
|
|
|
|
|
|
_PageItemsSchema.__name__ = schema_name
|
|
|
return _PageItemsSchema
|
|
|
|
|
|
|
|
|
def condition_schema(base_schema_cls: type, control_config):
|
|
|
"""
|
|
|
多字段控制动态Schema创建
|
|
|
|
|
|
Args:
|
|
|
base_schema_class: 基础 Schema 类
|
|
|
control_config: 控制配置字典
|
|
|
{
|
|
|
"withDataList": ["data_list"], # 当 withDataList=true 时包含 data_list
|
|
|
"withTimestamps": ["created_at", "updated_at"], # 当 withTimestamps=true 时包含时间字段
|
|
|
"withStatus": ["status"], # 当 withStatus=true 时包含状态字段
|
|
|
}
|
|
|
"""
|
|
|
|
|
|
class _ConditionSchema(base_schema_cls):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
self._control_config = control_config
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
def dump(self, obj, many=None, **kwargs):
|
|
|
from flask import request
|
|
|
|
|
|
# 收集排除字段列表
|
|
|
exclude_fields = []
|
|
|
for control_field, target_fields in self._control_config.items():
|
|
|
try:
|
|
|
control_value = (
|
|
|
request.args.get(control_field, "false").lower() == "true"
|
|
|
)
|
|
|
except Exception:
|
|
|
# 没有请求上下文时,默认不排除字段
|
|
|
control_value = True
|
|
|
|
|
|
if not control_value:
|
|
|
exclude_fields.extend(target_fields)
|
|
|
|
|
|
# 如果有字段需要排除,创建临时 Schema
|
|
|
if exclude_fields:
|
|
|
temp_schema = base_schema_cls(exclude=exclude_fields)
|
|
|
return temp_schema.dump(obj, many=many, **kwargs)
|
|
|
else:
|
|
|
return super().dump(obj, many=many, **kwargs)
|
|
|
|
|
|
return _ConditionSchema
|
|
|
|
|
|
|
|
|
def custom_schema_name_resolver(schema):
|
|
|
"""
|
|
|
自定义 schema 名称解析器,解决循环引用导致的命名冲突
|
|
|
|
|
|
根据 APIFlask 官方文档:
|
|
|
- 函数接收一个 schema 对象作为参数
|
|
|
- 返回一个字符串作为 schema 的名称
|
|
|
- 用于解决多个 schema 解析为相同名称的问题
|
|
|
|
|
|
处理策略:
|
|
|
1. 优先使用 Meta.name(如果定义)
|
|
|
2. 自动移除 Schema 后缀
|
|
|
3. 为带有 exclude 参数的嵌套 schema 生成唯一名称
|
|
|
"""
|
|
|
schema_class = schema.__class__
|
|
|
|
|
|
# 1. 优先检查是否在 Meta 中定义了 name
|
|
|
if hasattr(schema_class, "Meta") and hasattr(schema_class.Meta, "name"):
|
|
|
base_name = schema_class.Meta.name
|
|
|
else:
|
|
|
# 2. 使用类名,移除 Schema 后缀
|
|
|
base_name = schema_class.__name__
|
|
|
if base_name.endswith("Schema"):
|
|
|
base_name = base_name[:-6]
|
|
|
if schema.partial: # 为部分模式添加 "Update" 后缀
|
|
|
base_name += "Update"
|
|
|
|
|
|
# 3. 处理嵌套时的 exclude 参数
|
|
|
# 当使用 Nested("SomeSchema", exclude=["field1", "field2"]) 时
|
|
|
# apispec 会创建新的 schema 实例,需要为其生成唯一名称
|
|
|
if hasattr(schema, "exclude") and schema.exclude:
|
|
|
# 将 exclude 的字段排序,确保相同的 exclude 组合生成相同的名称
|
|
|
excluded_fields = sorted(schema.exclude)
|
|
|
# 生成简洁的后缀:首字母大写拼接
|
|
|
# 例如:exclude=["children", "parent"] -> "ChildrenParent"
|
|
|
suffix = "".join([field.capitalize() for field in excluded_fields])
|
|
|
return f"{base_name}Exclude{suffix}"
|
|
|
|
|
|
# 4. 处理 only 参数(如果使用)
|
|
|
if hasattr(schema, "only") and schema.only:
|
|
|
only_fields = sorted(schema.only)
|
|
|
suffix = "".join([field.capitalize() for field in only_fields])
|
|
|
return f"{base_name}Only{suffix}"
|
|
|
|
|
|
return base_name
|