refactor: rebuild fastapi framework foundation
parent
69c845aacd
commit
9a71aa8c93
@ -1,8 +1,10 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
def test_example_ping():
|
||||
client = app.test_client()
|
||||
response = client.get("/example/ping")
|
||||
def test_health():
|
||||
response = TestClient(app).get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"]["pong"] is True
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from iti.applications.common.crud import BaseModelMixin
|
||||
from iti.applications.extensions import db
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from iti.db import Base, IdMixin, TimestampMixin
|
||||
|
||||
class Example(BaseModelMixin):
|
||||
|
||||
class Example(IdMixin, TimestampMixin, Base):
|
||||
__tablename__ = "biz_example"
|
||||
|
||||
name = db.Column(db.String(128), nullable=False, comment="名称")
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False, comment="名称")
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from apiflask import APIBlueprint
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from iti.applications.common.utils import success
|
||||
from iti.auth import require_permission
|
||||
from iti.responses import ok
|
||||
|
||||
bp = APIBlueprint("example", __name__, tag="Example")
|
||||
|
||||
router = APIRouter(prefix="/example", tags=["example"])
|
||||
|
||||
@bp.get("/ping")
|
||||
|
||||
@router.get("/ping", dependencies=[Depends(require_permission("example:item:list"))])
|
||||
def ping():
|
||||
return success({"pong": True})
|
||||
return ok({"pong": True})
|
||||
|
||||
@ -0,0 +1,76 @@
|
||||
# 审计
|
||||
|
||||
iTi-Flask 只提供审计事件工具和异步发送器。
|
||||
它不写 `sys_log`。
|
||||
|
||||
`sys_log` 由注册了 `iti-system` 的业务项目接收并入库。
|
||||
|
||||
## 配置
|
||||
|
||||
```python
|
||||
class DevConfig(BaseDevConfig):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.audit_enabled = True
|
||||
self.audit_service_name = "audit"
|
||||
self.services = {
|
||||
"audit": {
|
||||
"base_url": "http://hsyh-mes-phase2.local",
|
||||
"token": "change-me",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
接收方需要把同一个 token 配进 `service_tokens`。
|
||||
|
||||
```python
|
||||
self.service_tokens = {"hsyh-erp": "change-me"}
|
||||
```
|
||||
|
||||
## 操作日志
|
||||
|
||||
业务显式提供 before / after 快照。
|
||||
框架负责 diff、脱敏和异步发送。
|
||||
|
||||
```python
|
||||
from fastapi import Request
|
||||
from iti.audit import audit_operation
|
||||
|
||||
|
||||
def update_order(order_id: str, request: Request):
|
||||
before = {"qty": 1}
|
||||
after = {"qty": 2}
|
||||
audit_operation(
|
||||
request,
|
||||
title="修改生产订单",
|
||||
target_type="mo",
|
||||
target_id=order_id,
|
||||
before=before,
|
||||
after=after,
|
||||
)
|
||||
```
|
||||
|
||||
## 登录日志
|
||||
|
||||
系统包登录接口已调用 `audit_login()`。
|
||||
普通业务项目如需自定义登录,也使用同一个工具。
|
||||
|
||||
```python
|
||||
from iti.audit import audit_login
|
||||
|
||||
|
||||
def login(request):
|
||||
audit_login(request, success=True, desc="admin")
|
||||
```
|
||||
|
||||
## 接收入口
|
||||
|
||||
`iti-system` 提供:
|
||||
|
||||
```http
|
||||
POST /internal/audit/events
|
||||
POST /internal/audit/login
|
||||
POST /internal/audit/operation
|
||||
```
|
||||
|
||||
这些接口使用框架服务 token 鉴权。
|
||||
@ -0,0 +1,189 @@
|
||||
# 前端管理端接口契约
|
||||
|
||||
本轮不改前端。
|
||||
这份文档用于后续管理端接口适配。
|
||||
|
||||
## 响应包装
|
||||
|
||||
管理端 API 默认 HTTP 200。
|
||||
|
||||
成功:
|
||||
|
||||
```json
|
||||
{"data": {}, "code": 200, "message": "成功"}
|
||||
```
|
||||
|
||||
失败:
|
||||
|
||||
```json
|
||||
{"data": null, "code": 403, "message": "权限不足"}
|
||||
```
|
||||
|
||||
字段输出使用 camelCase。
|
||||
|
||||
## 认证
|
||||
|
||||
登录:
|
||||
|
||||
```http
|
||||
POST /auth/loginByPassword
|
||||
POST /auth/loginByCode
|
||||
POST /auth/register
|
||||
POST /auth/refresh
|
||||
POST /auth/logout
|
||||
GET /auth/codes
|
||||
```
|
||||
|
||||
登录响应核心字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"accessToken": "...",
|
||||
"tokenType": "Bearer",
|
||||
"expiresIn": 86400,
|
||||
"refreshToken": "...",
|
||||
"refreshExpiresIn": 2592000,
|
||||
"user": {}
|
||||
}
|
||||
```
|
||||
|
||||
后续请求:
|
||||
|
||||
```http
|
||||
Authorization: Bearer <accessToken>
|
||||
```
|
||||
|
||||
## 系统接口
|
||||
|
||||
用户:
|
||||
|
||||
```http
|
||||
GET /sys/user/current
|
||||
GET /sys/user/list
|
||||
GET /sys/user/page
|
||||
POST /sys/user
|
||||
PUT /sys/user/{id}
|
||||
DELETE /sys/user/{id}
|
||||
PUT /sys/user/password
|
||||
```
|
||||
|
||||
角色:
|
||||
|
||||
```http
|
||||
GET /sys/role/list
|
||||
GET /sys/role/page
|
||||
POST /sys/role
|
||||
PUT /sys/role/{id}
|
||||
DELETE /sys/role/{id}
|
||||
```
|
||||
|
||||
菜单:
|
||||
|
||||
```http
|
||||
GET /sys/menu/list
|
||||
GET /sys/menu/tree
|
||||
GET /sys/menu/exists
|
||||
POST /sys/menu
|
||||
PUT /sys/menu/{id}
|
||||
DELETE /sys/menu/{id}
|
||||
```
|
||||
|
||||
部门:
|
||||
|
||||
```http
|
||||
GET /sys/dept/list
|
||||
GET /sys/dept/page
|
||||
GET /sys/dept/tree
|
||||
POST /sys/dept
|
||||
PUT /sys/dept/{id}
|
||||
DELETE /sys/dept/{id}
|
||||
```
|
||||
|
||||
配置:
|
||||
|
||||
```http
|
||||
GET /sys/config/list
|
||||
GET /sys/config/page
|
||||
POST /sys/config
|
||||
PUT /sys/config/{id}
|
||||
DELETE /sys/config/{id}
|
||||
```
|
||||
|
||||
字典:
|
||||
|
||||
```http
|
||||
GET /sys/dict/type/page
|
||||
GET /sys/dict/type
|
||||
POST /sys/dict/type
|
||||
PUT /sys/dict/type/{id}
|
||||
DELETE /sys/dict/type/{id}
|
||||
GET /sys/dict/data/page
|
||||
GET /sys/dict/data/list
|
||||
GET /sys/dict/data/{id}
|
||||
GET /sys/dict/data
|
||||
POST /sys/dict/data
|
||||
PUT /sys/dict/data/{id}
|
||||
DELETE /sys/dict/data/{id}
|
||||
DELETE /sys/dict/data/batch
|
||||
```
|
||||
|
||||
日志:
|
||||
|
||||
```http
|
||||
GET /sys/log/page
|
||||
DELETE /sys/log/{id}
|
||||
DELETE /sys/log/batch
|
||||
```
|
||||
|
||||
文件:
|
||||
|
||||
```http
|
||||
POST /upload
|
||||
POST /upload/chunk/init
|
||||
POST /upload/chunk/upload
|
||||
POST /upload/chunk/merge
|
||||
DELETE /upload/chunk/{uploadId}
|
||||
GET /upload/chunk/{uploadId}/progress
|
||||
POST /upload/chunk/cleanup
|
||||
GET /sys/file/{fileId}
|
||||
DELETE /sys/file/{fileId}
|
||||
POST /sys/file/{fileId}/restore
|
||||
DELETE /sys/file/{fileId}/permanent
|
||||
POST /sys/file/{fileId}/share
|
||||
DELETE /sys/file/{fileId}/share
|
||||
GET /file/{fileId}/download
|
||||
GET /file/{fileId}/preview
|
||||
GET /file/{fileId}/thumbnail
|
||||
GET /file/share/{shareCode}
|
||||
GET /file/share/{shareCode}/download
|
||||
```
|
||||
|
||||
用户扩展属性:
|
||||
|
||||
```http
|
||||
GET /sys/user-attributes/current
|
||||
PUT /sys/user-attributes/current
|
||||
GET /sys/user-attributes/{userId}
|
||||
PUT /sys/user-attributes/{userId}
|
||||
GET /sys/user-attributes/{userId}/{group}/{key}
|
||||
PUT /sys/user-attributes/{userId}/{group}/{key}
|
||||
DELETE /sys/user-attributes/{userId}/{group}
|
||||
DELETE /sys/user-attributes/{userId}/{group}/{key}
|
||||
POST /sys/user-attributes/{userId}/batch
|
||||
```
|
||||
|
||||
## 分页
|
||||
|
||||
分页响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"items": [],
|
||||
"page": {
|
||||
"page": 1,
|
||||
"size": 10,
|
||||
"pages": 1,
|
||||
"total": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
@ -1,3 +1,3 @@
|
||||
# SPDX-FileCopyrightText: 2025-present NoahLan <6995syu@163.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
from iti.app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
|
||||
@ -1,6 +1,312 @@
|
||||
from iti.applications import create_app
|
||||
from __future__ import annotations
|
||||
|
||||
app = create_app()
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Iterable, Mapping
|
||||
from contextvars import ContextVar
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from functools import wraps
|
||||
from inspect import isawaitable
|
||||
from typing import Any
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.dependencies.utils import (
|
||||
_should_embed_body_fields,
|
||||
get_body_field,
|
||||
get_dependant,
|
||||
get_flat_dependant,
|
||||
get_parameterless_sub_dependant,
|
||||
)
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.routing import request_response
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import Response
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from iti.auth.permissions import StaticPermissionProvider
|
||||
from iti.audit import init_audit
|
||||
from iti.cache import CacheManager
|
||||
from iti.config import BaseConfig, get_config
|
||||
from iti.db import configure_db
|
||||
from iti.exceptions import ItiError
|
||||
from iti.health import router as health_router
|
||||
from iti.limiter import SimpleLimiter
|
||||
from iti.logging_config import configure_logging, log_extra
|
||||
from iti.modules import init_modules
|
||||
from iti.responses.auto import is_envelope_payload, is_raw_response_request
|
||||
from iti.responses import fail
|
||||
from iti.service_client import init_service_clients
|
||||
from iti.tasks import init_task_runner
|
||||
|
||||
logger = logging.getLogger("iti")
|
||||
error_logger = logging.getLogger("iti.error")
|
||||
_current_request: ContextVar[Request | None] = ContextVar("iti_current_request", default=None)
|
||||
|
||||
|
||||
def create_app(
|
||||
config_name: str | None = None,
|
||||
*,
|
||||
modules: Iterable[Any] | None = None,
|
||||
config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None,
|
||||
permission_provider: Any | None = None,
|
||||
) -> FastAPI:
|
||||
config = _resolve_config(config_name, config_mapping)
|
||||
configure_logging(config)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
runner = getattr(app.state, "iti_task_runner", None)
|
||||
audit_dispatcher = getattr(app.state, "audit_dispatcher", None)
|
||||
if audit_dispatcher:
|
||||
audit_dispatcher.start()
|
||||
if runner and config.tasks_enabled:
|
||||
runner.start()
|
||||
yield
|
||||
if runner:
|
||||
runner.stop()
|
||||
if audit_dispatcher:
|
||||
audit_dispatcher.stop()
|
||||
for client in getattr(app.state, "iti_service_clients", {}).values():
|
||||
client.close()
|
||||
|
||||
app = FastAPI(
|
||||
title=config.app_name,
|
||||
debug=config.debug,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.state.config = config
|
||||
app.state.cache = CacheManager(default_timeout=config.cache_default_timeout)
|
||||
app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled)
|
||||
app.state.permission_provider = permission_provider or StaticPermissionProvider()
|
||||
|
||||
init_middlewares(app)
|
||||
|
||||
if config.cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
engine, sessionmaker = configure_db(
|
||||
config.database_url,
|
||||
echo=config.sqlalchemy_echo,
|
||||
pool_pre_ping=config.sqlalchemy_pool_pre_ping,
|
||||
)
|
||||
app.state.db_engine = engine
|
||||
app.state.db_sessionmaker = sessionmaker
|
||||
|
||||
init_error_handlers(app)
|
||||
init_service_clients(app, config.services)
|
||||
init_task_runner(app)
|
||||
init_audit(app)
|
||||
module_registry = init_modules(app, modules)
|
||||
app.state.iti_modules = module_registry
|
||||
if config.health_enabled:
|
||||
app.include_router(health_router)
|
||||
module_registry.run_phase("register_routes", app)
|
||||
module_registry.run_phase("register_permissions", app)
|
||||
module_registry.run_phase("register_menu_seed", app)
|
||||
install_auto_envelope(app)
|
||||
return app
|
||||
|
||||
|
||||
def init_middlewares(app: FastAPI) -> None:
|
||||
@app.middleware("http")
|
||||
async def request_context_middleware(request: Request, call_next):
|
||||
token = _current_request.set(request)
|
||||
trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex
|
||||
request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
|
||||
request.state.trace_id = trace_id
|
||||
request.state.request_id = request_id
|
||||
started_at = time.perf_counter()
|
||||
response: Response | None
|
||||
try:
|
||||
response = await call_next(request)
|
||||
except Exception:
|
||||
request.state.response_code = getattr(request.state, "response_code", 500)
|
||||
_log_request(request, started_at, 500)
|
||||
_current_request.reset(token)
|
||||
raise
|
||||
response.headers.setdefault("X-Trace-Id", trace_id)
|
||||
response.headers.setdefault("X-Request-Id", request_id)
|
||||
_log_request(request, started_at, response.status_code)
|
||||
_current_request.reset(token)
|
||||
return response
|
||||
|
||||
|
||||
def _log_request(request: Request, started_at: float, status_code: int) -> None:
|
||||
duration_ms = round((time.perf_counter() - started_at) * 1000, 2)
|
||||
logger.info(
|
||||
"request method=%s path=%s status=%s code=%s durationMs=%s ip=%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
status_code,
|
||||
getattr(request.state, "response_code", "-"),
|
||||
duration_ms,
|
||||
request.client.host if request.client else "-",
|
||||
extra=log_extra(request),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_config(
|
||||
config_name: str | None,
|
||||
config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None,
|
||||
) -> BaseConfig:
|
||||
if config_mapping is None:
|
||||
return get_config(config_name)
|
||||
if isinstance(config_mapping, Mapping):
|
||||
env_name = config_name or "dev"
|
||||
value = config_mapping.get(env_name, config_mapping.get("default"))
|
||||
if value is None:
|
||||
return get_config(config_name)
|
||||
return value() if isinstance(value, type) else value
|
||||
return config_mapping() if isinstance(config_mapping, type) else config_mapping
|
||||
|
||||
|
||||
def init_error_handlers(app: FastAPI) -> None:
|
||||
@app.exception_handler(ItiError)
|
||||
async def handle_iti_error(request: Request, exc: ItiError):
|
||||
request.state.response_code = exc.code
|
||||
return JSONResponse(
|
||||
status_code=request.app.state.config.response_envelope_http_status,
|
||||
content=fail(exc.message, code=exc.code, data=exc.data),
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def handle_validation_error(request: Request, exc: RequestValidationError):
|
||||
request.state.response_code = 422
|
||||
return JSONResponse(
|
||||
status_code=request.app.state.config.response_envelope_http_status,
|
||||
content=fail("参数验证错误", code=422, data=exc.errors()),
|
||||
)
|
||||
|
||||
@app.exception_handler(SQLAlchemyError)
|
||||
async def handle_db_error(request: Request, exc: SQLAlchemyError):
|
||||
request.state.response_code = 500
|
||||
error_logger.exception("database error", extra=log_extra(request))
|
||||
return JSONResponse(
|
||||
status_code=request.app.state.config.response_envelope_http_status,
|
||||
content=fail("数据库错误", code=500, data=str(exc)),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def handle_exception(request: Request, exc: Exception):
|
||||
request.state.response_code = 500
|
||||
error_logger.exception("server error", extra=log_extra(request))
|
||||
return JSONResponse(
|
||||
status_code=request.app.state.config.response_envelope_http_status,
|
||||
content=fail("服务器错误", code=500, data=str(exc)),
|
||||
)
|
||||
|
||||
|
||||
def to_plain_data(value: Any) -> Any:
|
||||
if is_dataclass(value):
|
||||
return asdict(value)
|
||||
return value
|
||||
|
||||
|
||||
def install_auto_envelope(app: FastAPI) -> None:
|
||||
config = app.state.config
|
||||
if not config.response_envelope_enabled:
|
||||
return
|
||||
raw_paths = tuple(config.raw_response_paths)
|
||||
for route in app.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
if getattr(route, "__iti_envelope_installed__", False):
|
||||
continue
|
||||
if _is_route_raw(route, raw_paths):
|
||||
continue
|
||||
original_call = route.dependant.call
|
||||
if original_call is None:
|
||||
continue
|
||||
route.endpoint = _wrap_endpoint_with_envelope(original_call)
|
||||
_rebuild_route_dependant(route)
|
||||
setattr(route, "__iti_envelope_installed__", True)
|
||||
|
||||
|
||||
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
|
||||
endpoint = route.endpoint
|
||||
if getattr(endpoint, "__iti_raw_response__", False):
|
||||
return True
|
||||
for path in route.path_format, route.path:
|
||||
request = _PathOnlyRequest(path)
|
||||
if is_raw_response_request(request, raw_paths):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _wrap_endpoint_with_envelope(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
if isawaitable(result):
|
||||
result = await result
|
||||
if isinstance(result, Response):
|
||||
return result
|
||||
payload = _to_jsonable(result)
|
||||
if is_envelope_payload(payload):
|
||||
_mark_response_code(args, kwargs, payload["code"])
|
||||
return payload
|
||||
_mark_response_code(args, kwargs, 200)
|
||||
return {"data": payload, "code": 200, "message": "成功"}
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _to_jsonable(value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(by_alias=True)
|
||||
if is_dataclass(value):
|
||||
return asdict(value)
|
||||
return value
|
||||
|
||||
|
||||
def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None:
|
||||
request = _request_from_call(args, kwargs) or _current_request.get()
|
||||
if request is not None:
|
||||
request.state.response_code = code
|
||||
|
||||
|
||||
def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
|
||||
for value in list(args) + list(kwargs.values()):
|
||||
if isinstance(value, Request):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _rebuild_route_dependant(route: APIRoute) -> None:
|
||||
route.dependant = get_dependant(
|
||||
path=route.path_format,
|
||||
call=route.endpoint,
|
||||
scope="function",
|
||||
)
|
||||
for depends in route.dependencies[::-1]:
|
||||
route.dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=route.path_format),
|
||||
)
|
||||
route._flat_dependant = get_flat_dependant(route.dependant)
|
||||
route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
|
||||
route.body_field = get_body_field(
|
||||
flat_dependant=route._flat_dependant,
|
||||
name=route.unique_id,
|
||||
embed_body_fields=route._embed_body_fields,
|
||||
)
|
||||
route.app = request_response(route.get_route_handler())
|
||||
|
||||
|
||||
class _PathOnlyRequest:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.url = type("URL", (), {"path": path})()
|
||||
self.scope = {"endpoint": None}
|
||||
|
||||
@ -1,2 +0,0 @@
|
||||
from .logger import setup_logger
|
||||
from .filter import ModelFilter
|
||||
@ -1,27 +0,0 @@
|
||||
"""
|
||||
通用工具模块
|
||||
"""
|
||||
|
||||
from .http import success, fail, page, pagination_builder
|
||||
from .schema import (
|
||||
Pagination,
|
||||
PaginationSchema,
|
||||
pagination_fields,
|
||||
pagination_schema_fields,
|
||||
page_schema,
|
||||
condition_schema,
|
||||
BaseSchema,
|
||||
custom_schema_name_resolver
|
||||
)
|
||||
from .tree import (
|
||||
build_tree_from_list,
|
||||
flatten_tree,
|
||||
find_node_by_id,
|
||||
get_node_path,
|
||||
filter_tree_by_condition,
|
||||
get_tree_depth,
|
||||
TreeKeyConfig,
|
||||
default_key_config,
|
||||
)
|
||||
from .str import camel_case
|
||||
from .time import parse_datetime_string
|
||||
@ -1,29 +0,0 @@
|
||||
from sqlalchemy.ext.declarative import DeclarativeBase
|
||||
|
||||
|
||||
def is_sqlalchemy_model(obj):
|
||||
"""
|
||||
判断对象是否为 SQLAlchemy 模型
|
||||
"""
|
||||
if isinstance(obj, DeclarativeBase):
|
||||
return True
|
||||
|
||||
if hasattr(obj, "_sa_instance_state"):
|
||||
return True
|
||||
|
||||
if hasattr(obj, "__mapper__"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_orm_result(data):
|
||||
"""
|
||||
判断数据是否为 ORM 查询结果
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
return False
|
||||
return is_sqlalchemy_model(data[0])
|
||||
|
||||
return is_sqlalchemy_model(data)
|
||||
@ -1,41 +0,0 @@
|
||||
from .error_views import init_error_views
|
||||
from .limit import init_limiter, limiter
|
||||
from .jwt import init_jwt, jwt
|
||||
from .db import init_db, db, ma
|
||||
from .migrate import init_migrate, migrate
|
||||
from .plugins import init_plugin, broadcast_execute
|
||||
from .encoder import init_encoder
|
||||
from .moment import init_moment, moment
|
||||
from .http import init_http
|
||||
from .error_handler import init_error_handler
|
||||
from .cache import init_cache, cache_simple, cache_redis
|
||||
from .event_bus import init_eventbus, eventbus
|
||||
from iti.applications.common.logger import init_logger
|
||||
|
||||
|
||||
def init_exts(app) -> None:
|
||||
# 日志
|
||||
init_logger(app)
|
||||
|
||||
# 插件
|
||||
init_plugin(app)
|
||||
broadcast_execute(app, "event_begin")
|
||||
|
||||
# http
|
||||
init_http(app)
|
||||
|
||||
# json
|
||||
init_encoder(app)
|
||||
init_moment(app)
|
||||
|
||||
# flask 扩展
|
||||
init_db(app)
|
||||
init_jwt(app)
|
||||
init_migrate(app)
|
||||
init_limiter(app)
|
||||
init_cache(app)
|
||||
init_eventbus(app)
|
||||
|
||||
# 系统蓝图相关
|
||||
init_error_views(app)
|
||||
init_error_handler(app)
|
||||
@ -1,15 +0,0 @@
|
||||
from flask_caching import Cache
|
||||
|
||||
cache_simple = Cache()
|
||||
cache_redis = Cache()
|
||||
|
||||
|
||||
def init_cache(app):
|
||||
simpleConfig = app.config.get("CACHE_SIMPLE", None)
|
||||
redisConfig = app.config.get("CACHE_REDIS", None)
|
||||
if simpleConfig is not None and simpleConfig.get("ENABLED", False):
|
||||
app.logger.info(f"初始化简单缓存: {simpleConfig}")
|
||||
cache_simple.init_app(app, config=simpleConfig)
|
||||
if redisConfig is not None and redisConfig.get("ENABLED", False):
|
||||
app.logger.info(f"初始化 Redis 缓存: {redisConfig}")
|
||||
cache_redis.init_app(app, config=redisConfig)
|
||||
@ -1,124 +0,0 @@
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_sqlalchemy.query import Query as BaseQuery
|
||||
from flask_marshmallow import Marshmallow
|
||||
import datetime
|
||||
import os
|
||||
from marshmallow import fields
|
||||
from marshmallow.validate import (
|
||||
URL,
|
||||
Email,
|
||||
Range,
|
||||
Length,
|
||||
Equal,
|
||||
Regexp,
|
||||
Predicate,
|
||||
NoneOf,
|
||||
OneOf,
|
||||
ContainsOnly,
|
||||
)
|
||||
from iti.applications.common.utils import fail
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
URL.default_message = "无效的链接"
|
||||
Email.default_message = "无效的邮箱地址"
|
||||
Range.message_min = "不能小于{min}"
|
||||
Range.message_max = "不能小于{max}"
|
||||
Range.message_all = "不能超过{min}和{max}这个范围"
|
||||
Length.message_min = "长度不得小于{min}位"
|
||||
Length.message_max = "长度不得大于{max}位"
|
||||
Length.message_all = "长度不能超过{min}和{max}这个范围"
|
||||
Length.message_equal = "长度必须等于{equal}位"
|
||||
Equal.default_message = "必须等于{other}"
|
||||
Regexp.default_message = "非法输入"
|
||||
Predicate.default_message = "非法输入"
|
||||
NoneOf.default_message = "非法输入"
|
||||
OneOf.default_message = "无效的选择"
|
||||
ContainsOnly.default_message = "一个或多个无效的选择"
|
||||
|
||||
fields.Field.default_error_messages = {
|
||||
"required": "缺少必要数据",
|
||||
"null": "数据不能为空",
|
||||
"validator_failed": "非法数据",
|
||||
}
|
||||
|
||||
fields.Str.default_error_messages = {"invalid": "不是合法文本"}
|
||||
fields.Int.default_error_messages = {"invalid": "不是合法整数"}
|
||||
fields.Number.default_error_messages = {"invalid": "不是合法数字"}
|
||||
fields.Boolean.default_error_messages = {"invalid": "不是合法布尔值"}
|
||||
|
||||
|
||||
class Query(BaseQuery):
|
||||
def soft_delete(self):
|
||||
"""
|
||||
软删除查询
|
||||
"""
|
||||
return self.update({"deleted_at": datetime.datetime.now()})
|
||||
|
||||
def logic_all(self):
|
||||
"""
|
||||
逻辑未删除查询
|
||||
"""
|
||||
return self.filter_by(deleted_at=None).all()
|
||||
|
||||
def all_json(self, schema: Marshmallow().Schema):
|
||||
"""
|
||||
查询结果转换为 JSON
|
||||
"""
|
||||
return schema(many=True).dump(self.all())
|
||||
|
||||
|
||||
naming_convention = {
|
||||
"ix": "ix_%(column_0_label)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(column_0_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
db = SQLAlchemy(
|
||||
query_class=Query, metadata=MetaData(naming_convention=naming_convention),
|
||||
)
|
||||
ma = Marshmallow()
|
||||
|
||||
|
||||
def init_db(app) -> None:
|
||||
"""
|
||||
初始化数据库
|
||||
"""
|
||||
db.init_app(app)
|
||||
ma.init_app(app)
|
||||
|
||||
# db错误处理
|
||||
_handle_db_error(app)
|
||||
|
||||
if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
|
||||
with app.app_context():
|
||||
try:
|
||||
db.engine.connect()
|
||||
except Exception as e:
|
||||
exit(f"数据库连接失败: {e}")
|
||||
|
||||
|
||||
def _handle_db_error(app):
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
show_error_details = app.config.get("SQLALCHEMY_SHOW_ERROR_DETAILS", False)
|
||||
|
||||
@app.errorhandler(SQLAlchemyError)
|
||||
def handle_sqlalchemy_db_error(error):
|
||||
"""
|
||||
SQLAlchemy 数据库错误处理
|
||||
"""
|
||||
app.logger.error(f"数据库错误: {error}")
|
||||
data = {
|
||||
"code": error.code if hasattr(error, "code") else 500,
|
||||
}
|
||||
if show_error_details:
|
||||
data["args"] = error.args if hasattr(error, "args") else None
|
||||
data["statement"] = error.statement if hasattr(error, "statement") else None
|
||||
data["params"] = error.params if hasattr(error, "params") else None
|
||||
return fail(
|
||||
"数据库错误",
|
||||
code=500,
|
||||
data=data,
|
||||
)
|
||||
@ -1,31 +0,0 @@
|
||||
from flask import request, render_template
|
||||
|
||||
from iti.applications.common.utils import fail
|
||||
|
||||
|
||||
def _wants_html() -> bool:
|
||||
return (
|
||||
request.accept_mimetypes.accept_html
|
||||
and request.accept_mimetypes["text/html"]
|
||||
>= request.accept_mimetypes["application/json"]
|
||||
)
|
||||
|
||||
|
||||
def init_error_views(app):
|
||||
@app.errorhandler(403)
|
||||
def forbidden(error):
|
||||
if not _wants_html():
|
||||
return fail(message="Forbidden", code=403), 200
|
||||
return render_template("errors/403.html"), 403
|
||||
|
||||
@app.errorhandler(404)
|
||||
def page_not_found(error):
|
||||
if not _wants_html():
|
||||
return fail(message="Not Found", code=404), 200
|
||||
return render_template("errors/404.html"), 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_server_error(error):
|
||||
if not _wants_html():
|
||||
return fail(message="Internal Server Error", code=500), 200
|
||||
return render_template("errors/500.html"), 500
|
||||
@ -1,10 +0,0 @@
|
||||
from .eventbus import EventBus
|
||||
|
||||
eventbus = EventBus()
|
||||
|
||||
|
||||
def init_eventbus(app):
|
||||
"""
|
||||
初始化事件总线
|
||||
"""
|
||||
eventbus.init_app(app)
|
||||
@ -1,6 +0,0 @@
|
||||
from .event_bus import EventBus
|
||||
from .event_middleware import EventMiddleware
|
||||
from .event_handler import (
|
||||
BaseEventHandler,
|
||||
FlaskEventHandler,
|
||||
)
|
||||
@ -1,70 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from iti.applications.common import setup_logger
|
||||
from flask import current_app
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class BaseEventHandler(ABC):
|
||||
"""
|
||||
事件处理器基类
|
||||
"""
|
||||
|
||||
def __init__(self, order: int = 0, async_mode: bool = False):
|
||||
self.order = order
|
||||
self.async_mode = async_mode
|
||||
|
||||
@abstractmethod
|
||||
def handle(self, data: any) -> any:
|
||||
"""
|
||||
处理事件
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_handle(self, data: any) -> any:
|
||||
"""
|
||||
处理事件前
|
||||
"""
|
||||
return data
|
||||
|
||||
def after_handle(self, data: any) -> any:
|
||||
"""
|
||||
处理事件后
|
||||
"""
|
||||
return data
|
||||
|
||||
def on_error(self, error: Exception, data: any) -> None:
|
||||
"""
|
||||
处理事件错误
|
||||
"""
|
||||
logger.error(f"事件处理错误: {error}, 数据: {data}", exc_info=True)
|
||||
|
||||
|
||||
class FlaskEventHandler(BaseEventHandler):
|
||||
"""Flask 事件处理器基类"""
|
||||
|
||||
def __init__(self, order: int = 0, async_mode: bool = False):
|
||||
super().__init__(order, async_mode)
|
||||
self._app = None
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
"""获取 Flask 应用实例"""
|
||||
if self._app is None:
|
||||
self._app = current_app
|
||||
return self._app
|
||||
|
||||
def handle(self, data: any) -> any:
|
||||
"""处理事件"""
|
||||
try:
|
||||
data = self.before_handle(data)
|
||||
result = self._do_handle(data)
|
||||
result = self.after_handle(result)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.on_error(e, data)
|
||||
raise
|
||||
|
||||
def _do_handle(self, data: any) -> any:
|
||||
"""实际处理逻辑"""
|
||||
pass
|
||||
@ -1,19 +0,0 @@
|
||||
class EventMiddleware:
|
||||
"""
|
||||
事件中间件基类
|
||||
"""
|
||||
|
||||
def __call__(self, event_name: str, args: tuple, kwargs: dict) -> tuple:
|
||||
"""
|
||||
处理事件
|
||||
返回处理后的 (args, kwargs)
|
||||
"""
|
||||
return args, kwargs
|
||||
|
||||
def on_error(
|
||||
self, error: Exception, event_name: str, args: tuple, kwargs: dict
|
||||
) -> None:
|
||||
"""
|
||||
处理事件错误
|
||||
"""
|
||||
pass
|
||||
@ -1,30 +0,0 @@
|
||||
from flask_jwt_extended import JWTManager
|
||||
from iti.applications.common.utils import fail
|
||||
|
||||
jwt = JWTManager()
|
||||
|
||||
|
||||
def init_jwt(app) -> None:
|
||||
"""
|
||||
初始化 JWT
|
||||
"""
|
||||
jwt.init_app(app)
|
||||
|
||||
# 自定义错误消息
|
||||
@jwt.unauthorized_loader
|
||||
def unauthorized_loader(_callback):
|
||||
return fail("缺少令牌参数 Authorization Bearer", code=401), 401
|
||||
|
||||
@jwt.invalid_token_loader
|
||||
def invalid_token_loader(_callback):
|
||||
return fail("无效的令牌", code=401, data=str(_callback)), 401
|
||||
|
||||
@jwt.expired_token_loader
|
||||
def expired_token_loader(_header, _payload):
|
||||
return fail("令牌已过期", code=401), 401
|
||||
|
||||
@jwt.user_identity_loader
|
||||
def user_identity_loader(user):
|
||||
if user is None or not hasattr(user, "id"):
|
||||
return None
|
||||
return user.id
|
||||
@ -1,11 +0,0 @@
|
||||
from flask_migrate import Migrate
|
||||
from .db import db
|
||||
|
||||
migrate = Migrate()
|
||||
|
||||
|
||||
def init_migrate(app) -> None:
|
||||
"""
|
||||
初始化迁移
|
||||
"""
|
||||
migrate.init_app(app, db)
|
||||
@ -1,8 +0,0 @@
|
||||
from flask_moment import Moment
|
||||
|
||||
|
||||
moment = Moment()
|
||||
|
||||
|
||||
def init_moment(app):
|
||||
moment.init_app(app)
|
||||
@ -1,11 +0,0 @@
|
||||
from iti.applications.extensions import broadcast_execute
|
||||
from iti.applications.routes.front import bp as frontend_bp
|
||||
|
||||
|
||||
def init_routes(app):
|
||||
# 前端路由注册(可选)
|
||||
if app.config.get("FRONTEND_ENABLED", False):
|
||||
app.register_blueprint(frontend_bp)
|
||||
|
||||
# 插件初始化
|
||||
broadcast_execute(app, "event_init")
|
||||
@ -1,43 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apiflask import APIBlueprint
|
||||
from flask import abort, current_app, send_from_directory
|
||||
|
||||
bp = APIBlueprint("front", __name__, tag="前端")
|
||||
|
||||
|
||||
def _get_frontend_path():
|
||||
frontend_path = current_app.config.get("FRONTEND_PATH")
|
||||
if not frontend_path:
|
||||
abort(404)
|
||||
|
||||
path = Path(frontend_path)
|
||||
if not path.is_absolute():
|
||||
base_dir = Path(current_app.config.get("BASE_DIR", os.getcwd()))
|
||||
path = base_dir / path
|
||||
return path.resolve()
|
||||
|
||||
|
||||
@bp.get("/")
|
||||
def index():
|
||||
"""渲染前端 SPA 入口页面"""
|
||||
frontend_path = _get_frontend_path()
|
||||
index_path = frontend_path / "index.html"
|
||||
if not index_path.exists():
|
||||
abort(404)
|
||||
return send_from_directory(frontend_path, "index.html")
|
||||
|
||||
|
||||
@bp.get("/<path:fallback>")
|
||||
def fallback(fallback):
|
||||
"""兜底: 避免history模式下的影响"""
|
||||
frontend_path = _get_frontend_path()
|
||||
target_path = frontend_path / fallback
|
||||
if target_path.exists() and target_path.is_file():
|
||||
return send_from_directory(frontend_path, fallback)
|
||||
|
||||
index_path = frontend_path / "index.html"
|
||||
if not index_path.exists():
|
||||
abort(404)
|
||||
return send_from_directory(frontend_path, "index.html")
|
||||
@ -1,4 +0,0 @@
|
||||
def init_services(app) -> None:
|
||||
"""初始化Services"""
|
||||
return None
|
||||
|
||||
@ -0,0 +1,257 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from inspect import isawaitable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from iti.auth import Actor
|
||||
from iti.service_client import ServiceClientError, service_client
|
||||
|
||||
|
||||
logger = logging.getLogger("iti.audit")
|
||||
SENSITIVE_KEYS = {"password", "token", "authorization", "secret", "refreshToken", "refresh_token"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuditEvent:
|
||||
type: str
|
||||
title: str
|
||||
success: bool = True
|
||||
actor_id: str | None = None
|
||||
actor_type: str | None = None
|
||||
method: str | None = None
|
||||
path: str | None = None
|
||||
ip: str | None = None
|
||||
user_agent: str | None = None
|
||||
target_type: str | None = None
|
||||
target_id: str | None = None
|
||||
diff: dict[str, Any] | None = None
|
||||
desc: str | None = None
|
||||
error: str | None = None
|
||||
trace_id: str | None = None
|
||||
occurred_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
def payload(self) -> dict[str, Any]:
|
||||
return {key: value for key, value in asdict(self).items() if value is not None}
|
||||
|
||||
|
||||
class AuditDispatcher:
|
||||
def __init__(self, app) -> None:
|
||||
self.app = app
|
||||
config = app.state.config
|
||||
self.enabled = bool(config.audit_enabled)
|
||||
self.service_name = config.audit_service_name
|
||||
self.batch_size = max(int(config.audit_batch_size), 1)
|
||||
self.flush_interval = float(config.audit_flush_interval_seconds)
|
||||
self._queue: queue.Queue[AuditEvent] = queue.Queue(maxsize=config.audit_queue_size)
|
||||
self._stop = threading.Event()
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
self._thread = threading.Thread(target=self._loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
if self._thread:
|
||||
self._thread.join(timeout=3)
|
||||
|
||||
def emit(self, event: AuditEvent) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
self._queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
self._queue.put_nowait(event)
|
||||
except queue.Empty:
|
||||
logger.warning("audit queue full and event dropped")
|
||||
|
||||
def _loop(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
batch = self._drain()
|
||||
if batch:
|
||||
self._send(batch)
|
||||
self._stop.wait(self.flush_interval)
|
||||
batch = self._drain()
|
||||
if batch:
|
||||
self._send(batch)
|
||||
|
||||
def _drain(self) -> list[AuditEvent]:
|
||||
batch: list[AuditEvent] = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
batch.append(self._queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
return batch
|
||||
|
||||
def _send(self, batch: list[AuditEvent]) -> None:
|
||||
try:
|
||||
client = service_client(self.app, self.service_name)
|
||||
client.post("/internal/audit/events", json={"events": [item.payload() for item in batch]})
|
||||
except ServiceClientError as exc:
|
||||
logger.warning("audit send failed: %s", exc)
|
||||
|
||||
|
||||
def init_audit(app) -> AuditDispatcher:
|
||||
dispatcher = AuditDispatcher(app)
|
||||
app.state.audit_dispatcher = dispatcher
|
||||
return dispatcher
|
||||
|
||||
|
||||
def audit_operation(
|
||||
request: Request,
|
||||
*,
|
||||
title: str,
|
||||
target_type: str | None = None,
|
||||
target_id: str | None = None,
|
||||
before: dict[str, Any] | None = None,
|
||||
after: dict[str, Any] | None = None,
|
||||
success: bool = True,
|
||||
desc: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
dispatcher = getattr(request.app.state, "audit_dispatcher", None)
|
||||
if dispatcher is None:
|
||||
return
|
||||
actor = getattr(request.state, "actor", None)
|
||||
dispatcher.emit(
|
||||
AuditEvent(
|
||||
type="operation",
|
||||
title=title,
|
||||
success=success,
|
||||
actor_id=getattr(actor, "id", None),
|
||||
actor_type=getattr(actor, "type", None),
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
diff=build_diff(before, after) if before is not None or after is not None else None,
|
||||
desc=desc,
|
||||
error=error,
|
||||
trace_id=getattr(request.state, "trace_id", None),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def operation_log(
|
||||
title: str,
|
||||
*,
|
||||
target_type: str | None = None,
|
||||
) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
request = _find_request(args, kwargs)
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
if isawaitable(result):
|
||||
result = await result
|
||||
if request is not None:
|
||||
audit_operation(
|
||||
request,
|
||||
title=title,
|
||||
target_type=target_type,
|
||||
)
|
||||
return result
|
||||
except Exception as exc:
|
||||
if request is not None:
|
||||
audit_operation(
|
||||
request,
|
||||
title=title,
|
||||
target_type=target_type,
|
||||
success=False,
|
||||
error=str(exc),
|
||||
)
|
||||
raise
|
||||
|
||||
return async_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
|
||||
for value in list(args) + list(kwargs.values()):
|
||||
if isinstance(value, Request):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def audit_login(
|
||||
request: Request,
|
||||
*,
|
||||
title: str = "登录",
|
||||
actor: Actor | None = None,
|
||||
success: bool = True,
|
||||
desc: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
dispatcher = getattr(request.app.state, "audit_dispatcher", None)
|
||||
if dispatcher is None:
|
||||
return
|
||||
actor = actor or getattr(request.state, "actor", None)
|
||||
dispatcher.emit(
|
||||
AuditEvent(
|
||||
type="login",
|
||||
title=title,
|
||||
success=success,
|
||||
actor_id=getattr(actor, "id", None),
|
||||
actor_type=getattr(actor, "type", None),
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
desc=desc,
|
||||
error=error,
|
||||
trace_id=getattr(request.state, "trace_id", None),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_diff(before: dict[str, Any] | None, after: dict[str, Any] | None) -> dict[str, Any]:
|
||||
before = before or {}
|
||||
after = after or {}
|
||||
keys = sorted(set(before) | set(after))
|
||||
changes = {}
|
||||
for key in keys:
|
||||
old = before.get(key)
|
||||
new = after.get(key)
|
||||
if old != new:
|
||||
changes[key] = {"before": sanitize_value(key, old), "after": sanitize_value(key, new)}
|
||||
return changes
|
||||
|
||||
|
||||
def sanitize(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
result = {}
|
||||
for key, item in value.items():
|
||||
if str(key) in SENSITIVE_KEYS:
|
||||
result[key] = "***"
|
||||
else:
|
||||
result[key] = sanitize(item)
|
||||
return result
|
||||
if isinstance(value, list):
|
||||
return [sanitize(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def sanitize_value(key: str, value: Any) -> Any:
|
||||
if key in SENSITIVE_KEYS:
|
||||
return "***"
|
||||
return sanitize(value)
|
||||
@ -0,0 +1,33 @@
|
||||
from .jwt import create_access_token, create_refresh_token, decode_token
|
||||
from .permissions import (
|
||||
Actor,
|
||||
Principal,
|
||||
PermissionProvider,
|
||||
StaticPermissionProvider,
|
||||
get_principal,
|
||||
get_service_actor,
|
||||
require_actor,
|
||||
require_permission,
|
||||
require_service,
|
||||
require_service_scope,
|
||||
require_user,
|
||||
set_permission_provider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Actor",
|
||||
"PermissionProvider",
|
||||
"Principal",
|
||||
"StaticPermissionProvider",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_token",
|
||||
"get_principal",
|
||||
"get_service_actor",
|
||||
"require_actor",
|
||||
"require_permission",
|
||||
"require_service",
|
||||
"require_service_scope",
|
||||
"require_user",
|
||||
"set_permission_provider",
|
||||
]
|
||||
@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from iti.config import BaseConfig
|
||||
from iti.exceptions import Unauthorized
|
||||
|
||||
|
||||
def _create_token(
|
||||
subject: str,
|
||||
config: BaseConfig,
|
||||
*,
|
||||
token_type: str,
|
||||
expires_seconds: int,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": subject,
|
||||
"type": token_type,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int((now + timedelta(seconds=expires_seconds)).timestamp()),
|
||||
**(claims or {}),
|
||||
}
|
||||
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: str,
|
||||
config: BaseConfig,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
return _create_token(
|
||||
subject,
|
||||
config,
|
||||
token_type="access",
|
||||
expires_seconds=config.jwt_access_token_expires_seconds,
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: str,
|
||||
config: BaseConfig,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
return _create_token(
|
||||
subject,
|
||||
config,
|
||||
token_type="refresh",
|
||||
expires_seconds=config.jwt_refresh_token_expires_seconds,
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
|
||||
def decode_token(token: str, config: BaseConfig, *, token_type: str | None = None) -> dict:
|
||||
try:
|
||||
payload = jwt.decode(token, config.jwt_secret_key, algorithms=[config.jwt_algorithm])
|
||||
except JWTError as exc:
|
||||
raise Unauthorized("无效的令牌") from exc
|
||||
if token_type is not None and payload.get("type") != token_type:
|
||||
raise Unauthorized("令牌类型错误")
|
||||
return payload
|
||||
@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol
|
||||
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from iti.exceptions import PermissionDenied, Unauthorized
|
||||
|
||||
from .jwt import decode_token
|
||||
|
||||
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Principal:
|
||||
id: str
|
||||
type: str = "user"
|
||||
permissions: frozenset[str] = field(default_factory=frozenset)
|
||||
roles: frozenset[str] = field(default_factory=frozenset)
|
||||
scopes: frozenset[str] = field(default_factory=frozenset)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Actor:
|
||||
id: str
|
||||
type: str
|
||||
principal: Principal | None = None
|
||||
service_name: str | None = None
|
||||
|
||||
|
||||
class PermissionProvider(Protocol):
|
||||
def load_principal(self, principal_id: str, request: Request) -> Principal | None:
|
||||
...
|
||||
|
||||
def has_permission(self, principal: Principal, code: str) -> bool:
|
||||
...
|
||||
|
||||
def has_scope(self, principal: Principal, scope: str) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class StaticPermissionProvider:
|
||||
def load_principal(self, principal_id: str, request: Request) -> Principal | None:
|
||||
return Principal(id=principal_id)
|
||||
|
||||
def has_permission(self, principal: Principal, code: str) -> bool:
|
||||
return code in principal.permissions
|
||||
|
||||
def has_scope(self, principal: Principal, scope: str) -> bool:
|
||||
return scope in principal.scopes
|
||||
|
||||
|
||||
permission_provider: PermissionProvider = StaticPermissionProvider()
|
||||
|
||||
|
||||
def set_permission_provider(provider: PermissionProvider) -> None:
|
||||
global permission_provider
|
||||
permission_provider = provider
|
||||
|
||||
|
||||
def get_principal(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
) -> Principal | None:
|
||||
if credentials is None:
|
||||
return None
|
||||
config = request.app.state.config
|
||||
payload = decode_token(credentials.credentials, config, token_type="access")
|
||||
principal_id = payload.get("sub")
|
||||
if not principal_id:
|
||||
raise Unauthorized("无效的令牌")
|
||||
provider = getattr(request.app.state, "permission_provider", permission_provider)
|
||||
principal = provider.load_principal(principal_id, request)
|
||||
if principal is None:
|
||||
raise Unauthorized("用户不存在或已失效")
|
||||
request.state.principal = principal
|
||||
request.state.actor = Actor(id=principal.id, type="user", principal=principal)
|
||||
return principal
|
||||
|
||||
|
||||
def require_user(
|
||||
principal: Principal | None = Depends(get_principal),
|
||||
) -> Principal:
|
||||
if principal is None:
|
||||
raise Unauthorized("缺少令牌参数 Authorization Bearer")
|
||||
return principal
|
||||
|
||||
|
||||
def require_permission(code: str) -> Callable:
|
||||
def dependency(
|
||||
request: Request,
|
||||
principal: Principal = Depends(require_user),
|
||||
) -> Principal:
|
||||
provider = getattr(request.app.state, "permission_provider", permission_provider)
|
||||
if not provider.has_permission(principal, code):
|
||||
raise PermissionDenied("权限不足", code=403)
|
||||
return principal
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def require_service_scope(scope: str) -> Callable:
|
||||
def dependency(
|
||||
request: Request,
|
||||
principal: Principal = Depends(require_user),
|
||||
) -> Principal:
|
||||
provider = getattr(request.app.state, "permission_provider", permission_provider)
|
||||
if not provider.has_scope(principal, scope):
|
||||
raise PermissionDenied("服务权限不足", code=403)
|
||||
return principal
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def get_service_actor(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
) -> Actor | None:
|
||||
if credentials is None:
|
||||
return None
|
||||
service_name = match_service_token(
|
||||
getattr(request.app.state.config, "service_tokens", {}),
|
||||
credentials.credentials,
|
||||
)
|
||||
if service_name is None:
|
||||
return None
|
||||
actor = Actor(id=service_name, type="service", service_name=service_name)
|
||||
request.state.actor = actor
|
||||
return actor
|
||||
|
||||
|
||||
def match_service_token(tokens: Mapping[str, str], token: str) -> str | None:
|
||||
for name, expected in tokens.items():
|
||||
if expected and token == expected:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def require_service(service_name: str | None = None) -> Callable:
|
||||
def dependency(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
) -> Actor:
|
||||
actor = get_service_actor(request, credentials)
|
||||
if actor is None:
|
||||
raise Unauthorized("无效的服务令牌")
|
||||
if service_name is not None and actor.service_name != service_name:
|
||||
raise PermissionDenied("服务权限不足", code=403)
|
||||
return actor
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def require_actor(
|
||||
*,
|
||||
permissions: list[str] | tuple[str, ...] | None = None,
|
||||
allow_service: bool = False,
|
||||
service_name: str | None = None,
|
||||
) -> Callable:
|
||||
required_permissions = tuple(permissions or ())
|
||||
|
||||
def dependency(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
) -> Actor:
|
||||
if credentials is None:
|
||||
raise Unauthorized("缺少令牌参数 Authorization Bearer")
|
||||
|
||||
service_actor = get_service_actor(request, credentials)
|
||||
if service_actor is not None:
|
||||
if not allow_service:
|
||||
raise PermissionDenied("服务权限不足", code=403)
|
||||
if service_name is not None and service_actor.service_name != service_name:
|
||||
raise PermissionDenied("服务权限不足", code=403)
|
||||
return service_actor
|
||||
|
||||
principal = get_principal(request, credentials)
|
||||
if principal is None:
|
||||
raise Unauthorized("缺少令牌参数 Authorization Bearer")
|
||||
provider = getattr(request.app.state, "permission_provider", permission_provider)
|
||||
missing = [
|
||||
code
|
||||
for code in required_permissions
|
||||
if not provider.has_permission(principal, code)
|
||||
]
|
||||
if missing:
|
||||
raise PermissionDenied("权限不足", code=403)
|
||||
actor = Actor(id=principal.id, type="user", principal=principal)
|
||||
request.state.actor = actor
|
||||
return actor
|
||||
|
||||
return dependency
|
||||
@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheItem:
|
||||
value: Any
|
||||
expires_at: float | None
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(self, *, default_timeout: int = 300) -> None:
|
||||
self.default_timeout = default_timeout
|
||||
self._items: dict[str, CacheItem] = {}
|
||||
|
||||
def get(self, key: str) -> Any:
|
||||
item = self._items.get(key)
|
||||
if item is None:
|
||||
return None
|
||||
if item.expires_at is not None and item.expires_at < time.time():
|
||||
self._items.pop(key, None)
|
||||
return None
|
||||
return item.value
|
||||
|
||||
def set(self, key: str, value: Any, timeout: int | None = None) -> None:
|
||||
timeout = self.default_timeout if timeout is None else timeout
|
||||
expires_at = None if timeout <= 0 else time.time() + timeout
|
||||
self._items[key] = CacheItem(value=value, expires_at=expires_at)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._items.pop(key, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._items.clear()
|
||||
@ -1,6 +1,38 @@
|
||||
import click
|
||||
|
||||
from iti.config import get_config
|
||||
|
||||
|
||||
@click.group()
|
||||
def iti_cli() -> None:
|
||||
"""iTi-Flask framework commands."""
|
||||
|
||||
|
||||
@iti_cli.command("config")
|
||||
@click.option("--env", "env_name", default=None, help="Config environment name.")
|
||||
def show_config(env_name: str | None) -> None:
|
||||
config = get_config(env_name)
|
||||
click.echo(f"app_env={config.app_env}")
|
||||
click.echo(f"database_url={config.database_url}")
|
||||
click.echo(f"health_enabled={config.health_enabled}")
|
||||
click.echo(f"ready_check_db={config.ready_check_db}")
|
||||
|
||||
|
||||
@iti_cli.command("routes")
|
||||
@click.argument("app_import")
|
||||
def list_routes(app_import: str) -> None:
|
||||
app = _load_app(app_import)
|
||||
for route in app.routes:
|
||||
methods = ",".join(sorted(getattr(route, "methods", []) or []))
|
||||
path = getattr(route, "path", "")
|
||||
name = getattr(route, "name", "")
|
||||
click.echo(f"{methods:20} {path:40} {name}")
|
||||
|
||||
|
||||
def _load_app(app_import: str):
|
||||
module_name, _, attr_name = app_import.partition(":")
|
||||
if not module_name or not attr_name:
|
||||
raise click.ClickException("app import must use module:attribute")
|
||||
module = __import__(module_name, fromlist=[attr_name])
|
||||
app = getattr(module, attr_name)
|
||||
return app() if callable(app) else app
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from iti.exceptions import BizError
|
||||
|
||||
|
||||
class BizException(HTTPException):
|
||||
class BizException(BizError):
|
||||
def __init__(self, message: str = "操作失败", code: int = 500, data=None):
|
||||
self.message = message
|
||||
self.code = code
|
||||
@ -1,4 +1,7 @@
|
||||
class PermissionDeniedException(Exception):
|
||||
from iti.exceptions import PermissionDenied
|
||||
|
||||
|
||||
class PermissionDeniedException(PermissionDenied):
|
||||
"""
|
||||
权限拒绝异常
|
||||
"""
|
||||
@ -0,0 +1,13 @@
|
||||
from .base import AuditMixin, Base, IdMixin, TimestampMixin
|
||||
from .session import configure_db, get_db, reset_db, session_scope
|
||||
|
||||
__all__ = [
|
||||
"AuditMixin",
|
||||
"Base",
|
||||
"IdMixin",
|
||||
"TimestampMixin",
|
||||
"configure_db",
|
||||
"get_db",
|
||||
"reset_db",
|
||||
"session_scope",
|
||||
]
|
||||
@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, MetaData, String
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
naming_convention = {
|
||||
"ix": "ix_%(column_0_label)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(column_0_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
metadata = MetaData(naming_convention=naming_convention)
|
||||
|
||||
|
||||
class IdMixin:
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=lambda: uuid.uuid4().hex,
|
||||
comment="标识",
|
||||
)
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=datetime.now,
|
||||
nullable=False,
|
||||
comment="创建时间",
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=datetime.now,
|
||||
onupdate=datetime.now,
|
||||
nullable=False,
|
||||
comment="更新时间",
|
||||
)
|
||||
|
||||
|
||||
class AuditMixin:
|
||||
created_by: Mapped[str | None] = mapped_column(
|
||||
String(36), nullable=True, index=True, comment="创建人"
|
||||
)
|
||||
updated_by: Mapped[str | None] = mapped_column(
|
||||
String(36), nullable=True, index=True, comment="更新人"
|
||||
)
|
||||
@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
|
||||
engine: Engine | None = None
|
||||
SessionLocal: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def configure_db(
|
||||
database_url: str,
|
||||
*,
|
||||
echo: bool = False,
|
||||
pool_pre_ping: bool = True,
|
||||
) -> tuple[Engine, sessionmaker[Session]]:
|
||||
global engine, SessionLocal
|
||||
engine_kwargs = {
|
||||
"echo": echo,
|
||||
"pool_pre_ping": pool_pre_ping,
|
||||
"future": True,
|
||||
}
|
||||
if database_url.startswith("sqlite"):
|
||||
engine_kwargs["connect_args"] = {"check_same_thread": False}
|
||||
if database_url in {"sqlite://", "sqlite:///:memory:", "sqlite+pysqlite:///:memory:"}:
|
||||
engine_kwargs["poolclass"] = StaticPool
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
**engine_kwargs,
|
||||
)
|
||||
SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)
|
||||
return engine, SessionLocal
|
||||
|
||||
|
||||
def get_db(request: Request) -> Generator[Session, None, None]:
|
||||
factory = getattr(request.app.state, "db_sessionmaker", None)
|
||||
if factory is None:
|
||||
raise RuntimeError("database is not configured")
|
||||
db = factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope() -> Generator[Session, None, None]:
|
||||
if SessionLocal is None:
|
||||
raise RuntimeError("database is not configured")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def ping_database(db: Session) -> None:
|
||||
db.execute(text("SELECT 1"))
|
||||
|
||||
|
||||
def reset_db() -> None:
|
||||
global engine, SessionLocal
|
||||
if engine is not None:
|
||||
engine.dispose()
|
||||
engine = None
|
||||
SessionLocal = None
|
||||
@ -0,0 +1,3 @@
|
||||
from .bus import EventBus, eventbus
|
||||
|
||||
__all__ = ["EventBus", "eventbus"]
|
||||
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = logging.getLogger("iti.events")
|
||||
|
||||
|
||||
class EventBus:
|
||||
def __init__(self, *, max_workers: int = 10) -> None:
|
||||
self._handlers: dict[str, list[Callable]] = defaultdict(list)
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
def on(self, event_name: str):
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.register_handler(event_name, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def register_handler(self, event_name: str, handler: Callable) -> None:
|
||||
self._handlers[event_name].append(handler)
|
||||
|
||||
def emit(self, event_name: str, *args, async_mode: bool = False, **kwargs) -> None:
|
||||
for handler in list(self._handlers.get(event_name, [])):
|
||||
if async_mode:
|
||||
self._executor.submit(self._run_handler, handler, *args, **kwargs)
|
||||
else:
|
||||
self._run_handler(handler, *args, **kwargs)
|
||||
|
||||
def _run_handler(self, handler: Callable, *args, **kwargs) -> None:
|
||||
try:
|
||||
handler(*args, **kwargs)
|
||||
except Exception:
|
||||
logger.exception("event handler failed: %s", handler)
|
||||
|
||||
|
||||
eventbus = EventBus()
|
||||
@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ItiError(Exception):
|
||||
status_code = 500
|
||||
message = "服务器错误"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
*,
|
||||
code: int | None = None,
|
||||
status_code: int | None = None,
|
||||
data: Any = None,
|
||||
) -> None:
|
||||
self.message = message or self.message
|
||||
self.code = code or status_code or self.status_code
|
||||
self.status_code = status_code or self.status_code
|
||||
self.data = data
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class BizError(ItiError):
|
||||
status_code = 400
|
||||
message = "业务错误"
|
||||
|
||||
|
||||
class PermissionDenied(ItiError):
|
||||
status_code = 403
|
||||
message = "权限不足"
|
||||
|
||||
|
||||
class Unauthorized(ItiError):
|
||||
status_code = 401
|
||||
message = "未认证"
|
||||
@ -0,0 +1,3 @@
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from iti.db import get_db
|
||||
from iti.db.session import ping_database
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@router.get("/health", include_in_schema=False)
|
||||
def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/ready", include_in_schema=False)
|
||||
def ready(request: Request, db: Session = Depends(get_db)):
|
||||
config = request.app.state.config
|
||||
if config.ready_check_db:
|
||||
try:
|
||||
ping_database(db)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
content={"status": "error"},
|
||||
)
|
||||
return {"status": "ok"}
|
||||
@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from iti.exceptions import BizError
|
||||
|
||||
|
||||
class SimpleLimiter:
|
||||
def __init__(self, *, enabled: bool = True) -> None:
|
||||
self.enabled = enabled
|
||||
self._hits: dict[str, deque[float]] = defaultdict(deque)
|
||||
|
||||
def limit(self, rule: str) -> Callable:
|
||||
count, seconds = parse_rule(rule)
|
||||
|
||||
def dependency(request: Request) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
client = request.client.host if request.client else "unknown"
|
||||
key = f"{client}:{request.url.path}:{rule}"
|
||||
now = time.time()
|
||||
hits = self._hits[key]
|
||||
while hits and hits[0] <= now - seconds:
|
||||
hits.popleft()
|
||||
if len(hits) >= count:
|
||||
raise BizError("请求过于频繁,请稍后再试", code=429, status_code=429)
|
||||
hits.append(now)
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def parse_rule(rule: str) -> tuple[int, int]:
|
||||
parts = rule.strip().split()
|
||||
if len(parts) < 3 or parts[1] != "per":
|
||||
raise ValueError(f"invalid rate limit rule: {rule}")
|
||||
count = int(parts[0])
|
||||
unit = parts[2].lower()
|
||||
if unit.startswith("second"):
|
||||
seconds = 1
|
||||
elif unit.startswith("minute"):
|
||||
seconds = 60
|
||||
elif unit.startswith("hour"):
|
||||
seconds = 3600
|
||||
else:
|
||||
raise ValueError(f"invalid rate limit unit: {unit}")
|
||||
return count, seconds
|
||||
|
||||
|
||||
def limit(rule: str) -> Callable:
|
||||
def dependency(request: Request) -> None:
|
||||
limiter = getattr(request.app.state, "limiter", SimpleLimiter())
|
||||
return limiter.limit(rule)(request)
|
||||
|
||||
return Depends(dependency)
|
||||
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from iti.config import BaseConfig
|
||||
|
||||
|
||||
class SafeFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
for key in ("trace_id", "request_id", "actor_type", "actor_id", "response_code"):
|
||||
if not hasattr(record, key):
|
||||
setattr(record, key, "-")
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def configure_logging(config: BaseConfig) -> None:
|
||||
level = getattr(logging, config.log_level.upper(), logging.INFO)
|
||||
formatter = SafeFormatter(
|
||||
"%(asctime)s %(levelname)s %(name)s "
|
||||
"trace=%(trace_id)s actor=%(actor_type)s:%(actor_id)s code=%(response_code)s - %(message)s"
|
||||
)
|
||||
|
||||
root_logger = logging.getLogger("iti")
|
||||
root_logger.setLevel(level)
|
||||
root_logger.handlers.clear()
|
||||
root_logger.propagate = False
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(level)
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
error_logger = logging.getLogger("iti.error")
|
||||
error_logger.setLevel(logging.ERROR)
|
||||
error_logger.handlers.clear()
|
||||
error_logger.propagate = False
|
||||
|
||||
if config.log_file_enabled:
|
||||
log_dir = Path(config.log_dir)
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
app_handler = RotatingFileHandler(
|
||||
log_dir / "app.log",
|
||||
encoding="utf-8",
|
||||
maxBytes=config.log_max_bytes,
|
||||
backupCount=config.log_backup_count,
|
||||
)
|
||||
app_handler.setLevel(level)
|
||||
app_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(app_handler)
|
||||
|
||||
error_handler = RotatingFileHandler(
|
||||
log_dir / "error.log",
|
||||
encoding="utf-8",
|
||||
maxBytes=config.log_max_bytes,
|
||||
backupCount=config.log_backup_count,
|
||||
)
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
error_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(error_handler)
|
||||
error_logger.addHandler(error_handler)
|
||||
|
||||
|
||||
def log_extra(request: Any | None = None) -> dict[str, Any]:
|
||||
if request is None:
|
||||
return {
|
||||
"trace_id": "-",
|
||||
"request_id": "-",
|
||||
"actor_type": "-",
|
||||
"actor_id": "-",
|
||||
"response_code": "-",
|
||||
}
|
||||
actor = getattr(request.state, "actor", None)
|
||||
principal = getattr(request.state, "principal", None)
|
||||
return {
|
||||
"trace_id": getattr(request.state, "trace_id", "-"),
|
||||
"request_id": getattr(request.state, "request_id", "-"),
|
||||
"actor_type": getattr(actor, "type", None) or ("user" if principal else "-"),
|
||||
"actor_id": getattr(actor, "id", None) or getattr(principal, "id", "-"),
|
||||
"response_code": getattr(request.state, "response_code", "-"),
|
||||
}
|
||||
@ -0,0 +1,4 @@
|
||||
from .auto import raw_response
|
||||
from .envelope import Envelope, fail, ok, page, pagination
|
||||
|
||||
__all__ = ["Envelope", "fail", "ok", "page", "pagination", "raw_response"]
|
||||
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
RAW_RESPONSE_ATTR = "__iti_raw_response__"
|
||||
ENVELOPE_FIELDS = {"data", "code", "message"}
|
||||
|
||||
|
||||
def raw_response(func: Callable) -> Callable:
|
||||
setattr(func, RAW_RESPONSE_ATTR, True)
|
||||
return func
|
||||
|
||||
|
||||
def is_envelope_payload(value: Any) -> bool:
|
||||
return isinstance(value, dict) and ENVELOPE_FIELDS.issubset(value.keys())
|
||||
|
||||
|
||||
def is_raw_response_request(request: Request, raw_paths: Iterable[str]) -> bool:
|
||||
endpoint = request.scope.get("endpoint")
|
||||
if endpoint is not None and getattr(endpoint, RAW_RESPONSE_ATTR, False):
|
||||
return True
|
||||
|
||||
path = request.url.path
|
||||
for pattern in raw_paths:
|
||||
if pattern == path or fnmatch.fnmatch(path, pattern):
|
||||
return True
|
||||
return False
|
||||
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from iti.schemas import to_camel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Envelope(BaseModel, Generic[T]):
|
||||
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
|
||||
|
||||
data: T | None = None
|
||||
code: int = 200
|
||||
message: str = "成功"
|
||||
|
||||
|
||||
def ok(data: Any = None, message: str = "成功", code: int = 200) -> dict[str, Any]:
|
||||
return {"data": data, "code": code, "message": message}
|
||||
|
||||
|
||||
def fail(
|
||||
message: str = "操作失败", code: int = 500, data: Any = None
|
||||
) -> dict[str, Any]:
|
||||
return {"data": data, "code": code, "message": message}
|
||||
|
||||
|
||||
def pagination(
|
||||
*,
|
||||
page: int = 1,
|
||||
size: int = 10,
|
||||
total: int = 0,
|
||||
pages: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
pages = pages if pages is not None else math.ceil(total / size) if size > 0 else 0
|
||||
return {
|
||||
"page": page,
|
||||
"size": size,
|
||||
"pages": pages,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
def page(
|
||||
items: list[Any],
|
||||
page_info: dict[str, Any] | None = None,
|
||||
message: str = "成功",
|
||||
code: int = 200,
|
||||
) -> dict[str, Any]:
|
||||
return ok({"items": items, "page": page_info or pagination()}, message, code)
|
||||
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer
|
||||
|
||||
|
||||
def to_camel(value: str) -> str:
|
||||
head, *tail = value.split("_")
|
||||
return head + "".join(item[:1].upper() + item[1:] for item in tail)
|
||||
|
||||
|
||||
class ItiModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
alias_generator=to_camel,
|
||||
populate_by_name=True,
|
||||
from_attributes=True,
|
||||
use_enum_values=True,
|
||||
)
|
||||
|
||||
@field_serializer("*", when_used="json")
|
||||
def serialize_datetime(self, value: Any) -> Any:
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
return value
|
||||
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from iti.common.enums import StorageTypeEnum
|
||||
|
||||
from .interface import StorageInterface
|
||||
from .local import LocalStorage
|
||||
|
||||
|
||||
class StorageManager:
|
||||
_instances: dict[str, StorageInterface] = {}
|
||||
|
||||
@classmethod
|
||||
def get_storage(
|
||||
cls,
|
||||
storage_type: Optional[Union[str, StorageTypeEnum]] = None,
|
||||
*,
|
||||
config: dict | None = None,
|
||||
base_dir: str | os.PathLike | None = None,
|
||||
) -> StorageInterface:
|
||||
config = config or {}
|
||||
storage_type_str = cls._normalize_storage_type(storage_type, config)
|
||||
if storage_type_str not in cls._instances:
|
||||
cls._instances[storage_type_str] = cls._create_storage(
|
||||
storage_type_str,
|
||||
config,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
return cls._instances[storage_type_str]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_storage_type(
|
||||
storage_type: Optional[Union[str, StorageTypeEnum]],
|
||||
config: dict,
|
||||
) -> str:
|
||||
if storage_type is None:
|
||||
return config.get("DEFAULT_STORAGE_TYPE", StorageTypeEnum.LOCAL.value)
|
||||
if isinstance(storage_type, StorageTypeEnum):
|
||||
return storage_type.value
|
||||
return storage_type
|
||||
|
||||
@staticmethod
|
||||
def _create_storage(
|
||||
storage_type: str,
|
||||
config: dict,
|
||||
*,
|
||||
base_dir: str | os.PathLike | None = None,
|
||||
) -> StorageInterface:
|
||||
if storage_type == StorageTypeEnum.LOCAL.value:
|
||||
local_config = dict(config.get("LOCAL", {}))
|
||||
if not local_config.get("base_path"):
|
||||
local_config["base_path"] = str(Path(base_dir or Path.cwd()) / "runtime" / "uploads")
|
||||
return LocalStorage(local_config)
|
||||
|
||||
if storage_type == StorageTypeEnum.ALIYUN_OSS.value:
|
||||
from .aliyun_oss import AliyunOSSStorage
|
||||
|
||||
return AliyunOSSStorage(config.get("ALIYUN_OSS", {}))
|
||||
if storage_type == StorageTypeEnum.TENCENT_COS.value:
|
||||
from .tencent_cos import TencentCOSStorage
|
||||
|
||||
return TencentCOSStorage(config.get("TENCENT_COS", {}))
|
||||
if storage_type == StorageTypeEnum.QINIU_KODO.value:
|
||||
from .qiniu_kodo import QiniuKodoStorage
|
||||
|
||||
return QiniuKodoStorage(config.get("QINIU_KODO", {}))
|
||||
if storage_type == StorageTypeEnum.HUAWEI_OBS.value:
|
||||
from .huawei_obs import HuaweiOBSStorage
|
||||
|
||||
return HuaweiOBSStorage(config.get("HUAWEI_OBS", {}))
|
||||
if storage_type == StorageTypeEnum.MINIO.value:
|
||||
from .minio_storage import MinIOStorage
|
||||
|
||||
return MinIOStorage(config.get("MINIO", {}))
|
||||
if storage_type == StorageTypeEnum.AWS_S3.value:
|
||||
raise NotImplementedError("AWS S3 适配器尚未实现")
|
||||
raise ValueError(f"未支持的存储类型: {storage_type}")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue