|
|
from dataclasses import field
|
|
|
from marshmallow_dataclass import dataclass
|
|
|
from apiflask.fields import Integer, URL, Field
|
|
|
from marshmallow import fields
|
|
|
from apiflask import Schema
|
|
|
from .str import camel_case
|
|
|
|
|
|
|
|
|
# ==================== Schema 定义 ====================
|
|
|
class BaseSchema(Schema):
    """
    Project-wide base Schema.

    1. Ordered output: fields are emitted in declaration order.
    2. Unknown input keys do not raise; they are kept in the loaded data.
    3. Fields without an explicit ``data_key`` are exposed in lower camelCase.
    """

    class Meta:
        ordered = True
        # marshmallow only accepts its lowercase constants here ("include",
        # "exclude", "raise" == marshmallow.INCLUDE / EXCLUDE / RAISE). An
        # uppercase "INCLUDE" is rejected by current marshmallow (and silently
        # acts like EXCLUDE in old releases), so unknown keys were never kept.
        unknown = "include"

    def on_bind_field(self, field_name: str, field_obj: Field) -> None:
        """
        marshmallow hook run when a field is bound to this schema.

        If the field has no explicit ``data_key``, its external key becomes
        ``camel_case(field_name)`` (lower camelCase), while the Python
        attribute keeps its original name. An explicit ``data_key`` always wins.

        Args:
            field_name: attribute name of the field on the schema.
            field_obj: the field instance being bound.
        """
        if field_obj.data_key is None:
            field_obj.data_key = camel_case(field_name)
|
|
|
|
|
|
|
|
|
@dataclass(base_schema=Schema)
class Pagination:
    """
    Pagination metadata declared as a marshmallow-dataclass.

    Declares the same fields, in the same order, as ``PaginationSchema``:
    four integer counters followed by five navigation URLs. Each field's
    ``metadata`` carries a nested "metadata" dict holding its description.
    The URL fields are flagged ``dump_only`` and default to "".
    """

    # ---- counters ----
    page: int = field(
        metadata={"metadata": {"description": "当前页码"}},
        default=1,
    )
    size: int = field(
        metadata={"metadata": {"description": "每页数量"}},
        default=10,
    )
    pages: int = field(
        metadata={"metadata": {"description": "总页数"}},
        default=0,
    )
    total: int = field(
        metadata={"metadata": {"description": "总记录数"}},
        default=0,
    )

    # ---- navigation links (dump_only) ----
    current: str = field(
        metadata={"dump_only": True, "metadata": {"description": "当前页URL"}},
        default="",
    )
    next: str = field(
        metadata={"dump_only": True, "metadata": {"description": "下一页URL"}},
        default="",
    )
    prev: str = field(
        metadata={"dump_only": True, "metadata": {"description": "上一页URL"}},
        default="",
    )
    first: str = field(
        metadata={"dump_only": True, "metadata": {"description": "首页URL"}},
        default="",
    )
    last: str = field(
        metadata={"dump_only": True, "metadata": {"description": "末页URL"}},
        default="",
    )
|
|
|
|
|
|
|
|
|
# Names of all dataclass fields declared on Pagination, in declaration order.
# The keys of __dataclass_fields__ are exactly the field names.
pagination_fields = list(Pagination.__dataclass_fields__)
|
|
|
|
|
|
|
|
|
class PaginationSchema(Schema):
    """
    Pagination metadata schema, shaped like APIFlask's built-in one.

    Unlike APIFlask's version, the page-size key is named ``size`` rather than
    ``per_page``. The five navigation URL fields are dump-only.
    """

    # ---- counters ----
    page = Integer(metadata={"description": "当前页码"})
    size = Integer(metadata={"description": "每页数量"})  # APIFlask: per_page
    pages = Integer(metadata={"description": "总页数"})
    total = Integer(metadata={"description": "总记录数"})

    # ---- navigation links (output only) ----
    current = URL(dump_only=True, metadata={"description": "当前页URL"})
    next = URL(dump_only=True, metadata={"description": "下一页URL"})
    prev = URL(dump_only=True, metadata={"description": "上一页URL"})
    first = URL(dump_only=True, metadata={"description": "首页URL"})
    last = URL(dump_only=True, metadata={"description": "末页URL"})
|
|
|
|
|
|
|
|
|
# Names of the fields declared on PaginationSchema, in declaration order.
pagination_schema_fields = [*PaginationSchema._declared_fields]
|
|
|
|
|
|
|
|
|
# Generated page schemas, keyed by (item schema, schema name). Handing out one
# class per key avoids several distinct classes sharing the same class name,
# which spec generators treat as conflicting schemas.
_PAGE_SCHEMA_CACHE: dict[tuple[object, str], type] = {}


def page_schema(item_schema_cls: type, *, schema_name: str | None = None) -> type:
    """
    Build a paginated payload Schema around an item schema:

        { items: [Item], page: Pagination }

    Args:
        item_schema_cls: schema for a single item; passed to ``fields.Nested``.
        schema_name: optional class name for the generated schema. Defaults to
            ``PageItems[<item schema __name__>]`` (``PageItems[Item]`` when the
            item schema has no ``__name__``).

    Returns:
        A Schema subclass with a required ``items`` list of nested items and a
        required ``page`` object (``PaginationSchema``). Repeated calls with
        the same item schema and name return the same class.
    """
    if not schema_name:
        schema_name = f"PageItems[{getattr(item_schema_cls, '__name__', 'Item')}]"

    cache_key: tuple[object, str] | None = (item_schema_cls, schema_name)
    try:
        return _PAGE_SCHEMA_CACHE[cache_key]
    except KeyError:
        pass
    except TypeError:
        # Unhashable item schema (e.g. a dict of fields): build it uncached.
        cache_key = None

    class _PageItemsSchema(Schema):
        items = fields.Nested(item_schema_cls, many=True, required=True)
        page = fields.Nested(PaginationSchema, required=True)

    # NOTE(review): presumably renamed so the OpenAPI component gets a
    # readable name (APIFlask derives schema names from the class name).
    _PageItemsSchema.__name__ = schema_name

    if cache_key is not None:
        _PAGE_SCHEMA_CACHE[cache_key] = _PageItemsSchema
    return _PageItemsSchema
|
|
|
|
|
|
|
|
|
def condition_schema(base_schema_cls: type, control_config):
    """
    Create a Schema subclass whose fields are toggled by query-string flags.

    For every ``control_field -> target_fields`` entry, the target fields are
    dumped only when the current request's query string has
    ``control_field=true`` (compared case-insensitively). If the flag is
    missing or has any other value, those fields are excluded. Outside a Flask
    request context nothing is excluded.

    Args:
        base_schema_cls: the Schema class to extend.
        control_config: mapping of query parameter name -> field names, e.g.
            {
                "withDataList": ["data_list"],  # data_list only when withDataList=true
                "withTimestamps": ["created_at", "updated_at"],  # only when withTimestamps=true
                "withStatus": ["status"],  # status only when withStatus=true
            }

    Returns:
        A subclass of ``base_schema_cls``. When any field must be excluded,
        its ``dump`` builds a fresh ``base_schema_cls`` with the same
        constructor arguments plus the extra exclusions.
    """

    class _ConditionSchema(base_schema_cls):
        def __init__(self, *args, **kwargs):
            self._control_config = control_config
            # Remember how this instance was configured (many, only, exclude,
            # ...) so the per-dump variant in ``dump`` is configured the same.
            self._init_args = args
            self._init_kwargs = kwargs
            super().__init__(*args, **kwargs)

        def dump(self, obj, many=None, **kwargs):
            from flask import has_request_context, request

            # Without a request there are no query flags: keep every field.
            in_request = has_request_context()

            # Collect the fields whose control flag is not enabled.
            exclude_fields = []
            for control_field, target_fields in self._control_config.items():
                enabled = (
                    not in_request
                    or request.args.get(control_field, "false").lower() == "true"
                )
                if not enabled:
                    exclude_fields.extend(target_fields)

            if not exclude_fields:
                return super().dump(obj, many=many, **kwargs)

            # many=None means "use this schema's setting"; resolve it here,
            # otherwise the temporary schema would apply its own default.
            if many is None:
                many = self.many

            # Rebuild with the original constructor arguments, adding the
            # flag-driven exclusions to any exclusions the caller passed.
            init_kwargs = dict(self._init_kwargs)
            init_kwargs["exclude"] = (
                *(init_kwargs.get("exclude") or ()),
                *exclude_fields,
            )
            temp_schema = base_schema_cls(*self._init_args, **init_kwargs)
            return temp_schema.dump(obj, many=many, **kwargs)

    return _ConditionSchema
|